scirs2_stats/distributions/multivariate/dirichlet.rs
1//! Dirichlet distribution functions
2//!
3//! This module provides functionality for the Dirichlet distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::{Distribution, Gamma as RandGamma};
10use std::fmt::Debug;
11
/// Natural logarithm of the gamma function, `ln Γ(x)`, for positive `x`.
///
/// Stable Rust has no `lgamma` in `std`, so this is a self-contained
/// implementation:
///
/// * positive integers up to 20 use the exact identity `ln Γ(n) = ln((n-1)!)`;
/// * `x >= 0.5` uses the Lanczos approximation (g = 7, nine coefficients);
/// * `0 < x < 0.5` uses the reflection formula `Γ(x) Γ(1-x) = π / sin(πx)`,
///   which maps onto `1 - x > 0.5`, so the recursion is exactly one level deep.
///
/// Evaluation cost is constant regardless of the magnitude of `x`.
///
/// # Returns
///
/// `ln Γ(x)`; `+∞` for `x = +∞` and NaN for a NaN input.
///
/// # Panics
///
/// Panics if `x <= 0.0`.
fn lgamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    // Lanczos parameters (g = 7). The leading constant is the zeroth series
    // term; the remaining coefficients are divided by (z + 1), (z + 2), ...
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_C0: f64 = 0.999_999_999_999_809_9;
    const LANCZOS_COEFFS: [f64; 8] = [
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507343278686905,
        -0.13857109526572012,
        9.984_369_578_019_572e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        panic!("lgamma requires positive input");
    }
    if x.is_nan() {
        return f64::NAN;
    }
    if x.is_infinite() {
        return f64::INFINITY;
    }

    // Small integers: Γ(n) = (n-1)!, summed exactly in log space.
    // For n = 1 and n = 2 the range is empty and the result is ln(1) = 0.
    if x.fract() == 0.0 && x <= 20.0 {
        let n = x as u32;
        return (2..n).fold(0.0, |acc, k| acc + f64::from(k).ln());
    }

    // Reflection for small arguments. Because 1 - x > 0.5 the recursive call
    // goes straight to the integer or Lanczos branch and cannot bounce back.
    // ln(π) - ln(sin(πx)) is used instead of ln(π / sin(πx)) so that a tiny
    // `x` does not overflow the quotient.
    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - lgamma(1.0 - x);
    }

    // Lanczos approximation:
    // Γ(z + 1) ≈ sqrt(2π) · t^(z + 1/2) · e^(-t) · A(z), with t = z + g + 1/2.
    let z = x - 1.0;
    let t = z + LANCZOS_G + 0.5;
    let series = LANCZOS_COEFFS
        .iter()
        .zip(1u32..)
        .fold(LANCZOS_C0, |acc, (&coef, k)| acc + coef / (z + f64::from(k)));

    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
75
/// Dirichlet distribution structure
///
/// Stores the concentration parameters together with a normalization term
/// that `Dirichlet::new` computes once at construction time.
#[derive(Debug, Clone)]
pub struct Dirichlet {
    /// Concentration parameters (alpha values)
    ///
    /// NOTE(review): this field is public, but `log_norm_const` is computed
    /// only in `Dirichlet::new`; nothing visible in this file refreshes the
    /// cached value if `alpha` is modified after construction.
    pub alpha: Array1<f64>,
    /// Dimension of the distribution (number of categories); set to
    /// `alpha.len()` by `Dirichlet::new`
    pub dim: usize,
    /// Natural log of the normalization constant,
    /// `ln B(α) = Σ ln Γ(αᵢ) − ln Γ(Σ αᵢ)` (cached for efficiency)
    log_norm_const: f64,
}
86
87impl Dirichlet {
88 /// Create a new Dirichlet distribution with given concentration parameters
89 ///
90 /// # Arguments
91 ///
92 /// * `alpha` - Concentration parameters (all values must be positive)
93 ///
94 /// # Returns
95 ///
96 /// * A new Dirichlet distribution instance
97 ///
98 /// # Examples
99 ///
100 /// ```
101 /// use scirs2_core::ndarray::array;
102 /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
103 ///
104 /// // Create a 3D Dirichlet distribution with symmetric parameters (equivalent to a uniform distribution over the simplex)
105 /// let alpha = array![1.0, 1.0, 1.0];
106 /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
107 /// ```
108 pub fn new<D>(alpha: ArrayBase<D, Ix1>) -> StatsResult<Self>
109 where
110 D: Data<Elem = f64>,
111 {
112 let alpha_owned = alpha.to_owned();
113 let dim = alpha_owned.len();
114
115 // Check that all _alpha values are positive
116 for &a in alpha_owned.iter() {
117 if a <= 0.0 {
118 return Err(StatsError::DomainError(
119 "All concentration parameters must be positive".to_string(),
120 ));
121 }
122 }
123
124 let alpha_sum = alpha_owned.sum();
125
126 // Compute the log normalization constant:
127 // ln[B(α)] = sum(ln[Γ(αᵢ)]) - ln[Γ(sum(αᵢ))]
128 let mut log_norm_const = 0.0;
129
130 // Sum of log(Gamma(alpha_i))
131 for &a in alpha_owned.iter() {
132 log_norm_const += lgamma(a);
133 }
134
135 // Subtract log(Gamma(sum(alpha_i)))
136 log_norm_const -= lgamma(alpha_sum);
137
138 Ok(Dirichlet {
139 alpha: alpha_owned,
140 dim,
141 log_norm_const,
142 })
143 }
144
145 /// Calculate the probability density function (PDF) at a given point
146 ///
147 /// # Arguments
148 ///
149 /// * `x` - The point at which to evaluate the PDF (must sum to 1)
150 ///
151 /// # Returns
152 ///
153 /// * The value of the PDF at the given point
154 ///
155 /// # Examples
156 ///
157 /// ```
158 /// use scirs2_core::ndarray::array;
159 /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
160 ///
161 /// let alpha = array![1.0, 1.0, 1.0];
162 /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
163 ///
164 /// // PDF for a uniform Dirichlet at any point on the simplex is 2 (in 3D)
165 /// let point = array![0.3, 0.3, 0.4];
166 /// let pdf_value = dirichlet.pdf(&point);
167 /// assert!((pdf_value - 2.0).abs() < 1e-10);
168 /// ```
169 pub fn pdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
170 where
171 D: Data<Elem = f64>,
172 {
173 if x.len() != self.dim {
174 return 0.0; // Return zero for invalid dimensions
175 }
176
177 // Check if x is on the simplex (all values > 0 and sum to 1)
178 let sum: f64 = x.iter().sum();
179 if (sum - 1.0).abs() > 1e-10 {
180 return 0.0; // Point not on the simplex
181 }
182
183 for &val in x.iter() {
184 if val <= 0.0 || val >= 1.0 {
185 return 0.0; // Values must be in (0, 1)
186 }
187 }
188
189 // Calculate the PDF using the formula:
190 // p(x|α) = [∏ xᵢ^(αᵢ-1)] / B(α)
191 // where B(α) is the multivariate beta function
192
193 // We'll work in log space for numerical stability
194 let log_pdf = self.logpdf(x);
195 log_pdf.exp()
196 }
197
198 /// Calculate the log probability density function (log PDF) at a given point
199 ///
200 /// # Arguments
201 ///
202 /// * `x` - The point at which to evaluate the log PDF (must sum to 1)
203 ///
204 /// # Returns
205 ///
206 /// * The value of the log PDF at the given point
207 ///
208 /// # Examples
209 ///
210 /// ```
211 /// use scirs2_core::ndarray::array;
212 /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
213 ///
214 /// let alpha = array![1.0, 1.0, 1.0];
215 /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
216 ///
217 /// let point = array![0.3, 0.3, 0.4];
218 /// let logpdf_value = dirichlet.logpdf(&point);
219 /// assert!((logpdf_value - 0.693).abs() < 1e-3); // ln(2) ≈ 0.693
220 /// ```
221 pub fn logpdf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
222 where
223 D: Data<Elem = f64>,
224 {
225 if x.len() != self.dim {
226 return f64::NEG_INFINITY; // Return -∞ for invalid dimensions
227 }
228
229 // Check if x is on the simplex (all values > 0 and sum to 1)
230 let sum: f64 = x.iter().sum();
231 if (sum - 1.0).abs() > 1e-10 {
232 return f64::NEG_INFINITY; // Point not on the simplex
233 }
234
235 for &val in x.iter() {
236 if val <= 0.0 || val >= 1.0 {
237 return f64::NEG_INFINITY; // Values must be in (0, 1)
238 }
239 }
240
241 // Calculate the log PDF using the formula:
242 // log p(x|α) = sum[(αᵢ-1)log(xᵢ)] - log B(α)
243 let mut log_pdf = -self.log_norm_const;
244
245 for i in 0..self.dim {
246 log_pdf += (self.alpha[i] - 1.0) * x[i].ln();
247 }
248
249 log_pdf
250 }
251
252 /// Generate random samples from the distribution
253 ///
254 /// # Arguments
255 ///
256 /// * `size` - Number of samples to generate
257 ///
258 /// # Returns
259 ///
260 /// * Matrix where each row is a random sample
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use scirs2_core::ndarray::array;
266 /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
267 ///
268 /// let alpha = array![1.0, 2.0, 3.0];
269 /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
270 ///
271 /// let samples = dirichlet.rvs(10).expect("Operation failed");
272 /// assert_eq!(samples.len(), 10);
273 /// assert_eq!(samples[0].len(), 3);
274 /// ```
275 pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
276 let mut rng = thread_rng();
277 let mut samples = Vec::with_capacity(size);
278
279 // Generate samples using the gamma method:
280 // 1. Generate independent gamma samples with shape αᵢ and scale=1
281 // 2. Normalize by their sum
282
283 for _ in 0..size {
284 let mut sample = Array1::<f64>::zeros(self.dim);
285 let mut sum = 0.0;
286
287 // Generate gamma samples
288 for i in 0..self.dim {
289 let gamma_dist = RandGamma::new(self.alpha[i], 1.0).map_err(|_| {
290 StatsError::ComputationError("Failed to create gamma distribution".to_string())
291 })?;
292
293 let gamma_sample = gamma_dist.sample(&mut rng);
294 sample[i] = gamma_sample;
295 sum += gamma_sample;
296 }
297
298 // Normalize to get a point on the simplex
299 sample.mapv_inplace(|x| x / sum);
300 samples.push(sample);
301 }
302
303 Ok(samples)
304 }
305
306 /// Generate a single random sample from the distribution
307 ///
308 /// # Returns
309 ///
310 /// * Vector representing a single sample
311 ///
312 /// # Examples
313 ///
314 /// ```
315 /// use scirs2_core::ndarray::array;
316 /// use scirs2_stats::distributions::multivariate::dirichlet::Dirichlet;
317 ///
318 /// let alpha = array![1.0, 2.0, 3.0];
319 /// let dirichlet = Dirichlet::new(alpha).expect("Operation failed");
320 ///
321 /// let sample = dirichlet.rvs_single().expect("Operation failed");
322 /// assert_eq!(sample.len(), 3);
323 /// ```
324 pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
325 let samples = self.rvs(1)?;
326 Ok(samples[0].clone())
327 }
328}
329
330/// Create a Dirichlet distribution with the given parameters.
331///
332/// This is a convenience function to create a Dirichlet distribution with
333/// the given concentration parameters.
334///
335/// # Arguments
336///
337/// * `alpha` - Concentration parameters (all values must be positive)
338///
339/// # Returns
340///
341/// * A Dirichlet distribution object
342///
343/// # Examples
344///
345/// ```
346/// use scirs2_core::ndarray::array;
347/// use scirs2_stats::distributions::multivariate;
348///
349/// let alpha = array![1.0, 1.0, 1.0];
350/// let dirichlet = multivariate::dirichlet(&alpha).expect("Operation failed");
351/// let point = array![0.3, 0.3, 0.4];
352/// let pdf_at_point = dirichlet.pdf(&point);
353/// ```
354#[allow(dead_code)]
355pub fn dirichlet<D>(alpha: &ArrayBase<D, Ix1>) -> StatsResult<Dirichlet>
356where
357 D: Data<Elem = f64>,
358{
359 Dirichlet::new(alpha.to_owned())
360}
361
362/// Implementation of SampleableDistribution for Dirichlet
363impl SampleableDistribution<Array1<f64>> for Dirichlet {
364 fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
365 self.rvs(size)
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Construction stores the parameters and their count unchanged.
    #[test]
    fn test_dirichlet_creation() {
        // Uniform Dirichlet
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha.clone()).expect("Operation failed");

        assert_eq!(dirichlet.dim, 3);
        assert_eq!(dirichlet.alpha, alpha);

        // Non-uniform Dirichlet
        let alpha2 = array![2.0, 3.0, 4.0];
        let dirichlet2 = Dirichlet::new(alpha2.clone()).expect("Operation failed");

        assert_eq!(dirichlet2.dim, 3);
        assert_eq!(dirichlet2.alpha, alpha2);
    }

    /// Non-positive concentration parameters are rejected.
    #[test]
    fn test_dirichlet_creation_errors() {
        // Zero alpha value
        let alpha = array![1.0, 0.0, 1.0];
        assert!(Dirichlet::new(alpha).is_err());

        // Negative alpha value
        let alpha = array![1.0, -1.0, 1.0];
        assert!(Dirichlet::new(alpha).is_err());
    }

    /// The uniform case has constant density; a concentrated case peaks at the center.
    #[test]
    fn test_dirichlet_pdf() {
        // Uniform Dirichlet (alpha = [1,1,1])
        // PDF value should be constant on the simplex: 2 for 3D
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("Operation failed");

        let point1 = array![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let point2 = array![0.2, 0.3, 0.5];

        assert_relative_eq!(dirichlet.pdf(&point1), 2.0, epsilon = 1e-10);
        assert_relative_eq!(dirichlet.pdf(&point2), 2.0, epsilon = 1e-10);

        // Concentrated Dirichlet
        let alpha = array![5.0, 5.0, 5.0];
        let concentrated = Dirichlet::new(alpha).expect("Operation failed");

        // PDF should be higher at the center than at the edges
        let center = array![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let edge = array![0.01, 0.01, 0.98];

        assert!(concentrated.pdf(&center) > concentrated.pdf(&edge));
    }

    /// Points off the open simplex have zero density.
    #[test]
    fn test_dirichlet_pdf_edge_cases() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("Operation failed");

        // Points not on the simplex
        let invalid1 = array![0.3, 0.3, 0.3]; // Sum < 1
        let invalid2 = array![0.5, 0.6, 0.2]; // Sum > 1
        let invalid3 = array![0.0, 0.5, 0.5]; // Contains 0
        let invalid4 = array![1.0, 0.0, 0.0]; // Contains 0 and 1

        assert_eq!(dirichlet.pdf(&invalid1), 0.0);
        assert_eq!(dirichlet.pdf(&invalid2), 0.0);
        assert_eq!(dirichlet.pdf(&invalid3), 0.0);
        assert_eq!(dirichlet.pdf(&invalid4), 0.0);
    }

    /// A point of the wrong length is outside the support for both densities.
    #[test]
    fn test_dirichlet_dimension_mismatch() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("Operation failed");

        let wrong_len = array![0.5, 0.5];

        assert_eq!(dirichlet.pdf(&wrong_len), 0.0);
        assert_eq!(dirichlet.logpdf(&wrong_len), f64::NEG_INFINITY);
    }

    /// The log density matches ln(2) for the uniform case and agrees with `pdf`.
    #[test]
    fn test_dirichlet_logpdf() {
        let alpha = array![1.0, 1.0, 1.0];
        let dirichlet = Dirichlet::new(alpha).expect("Operation failed");

        let point = array![0.3, 0.3, 0.4];

        // Log of uniform Dirichlet with alpha=[1,1,1] is ln(2) ≈ 0.693
        assert_relative_eq!(dirichlet.logpdf(&point), 0.693, epsilon = 1e-3);

        // Check that exp(logPDF) = PDF
        assert_relative_eq!(
            dirichlet.logpdf(&point).exp(),
            dirichlet.pdf(&point),
            epsilon = 1e-10
        );
    }

    /// Samples lie on the simplex and their mean approaches alpha / sum(alpha).
    #[test]
    fn test_dirichlet_rvs() {
        let alpha = array![1.0, 2.0, 3.0];
        let dirichlet = Dirichlet::new(alpha.clone()).expect("Operation failed");

        // Generate samples
        let n_samples = 1000;
        let samples = dirichlet.rvs(n_samples).expect("Operation failed");

        // Check number of samples
        assert_eq!(samples.len(), n_samples);

        for sample in &samples {
            // Each sample sums to 1 (within floating point error)
            let sum: f64 = sample.iter().sum();
            assert_relative_eq!(sum, 1.0, epsilon = 1e-10);

            // Every component is in [0,1]
            assert!(sample.iter().all(|val| (0.0..=1.0).contains(val)));
        }

        // Check sample mean is close to expected mean: E[X_i] = alpha_i / sum(alpha)
        let mut sample_total = [0.0; 3];
        for sample in &samples {
            for (total, &val) in sample_total.iter_mut().zip(sample.iter()) {
                *total += val;
            }
        }

        let alpha_sum = alpha.sum();
        for (&total, &a) in sample_total.iter().zip(alpha.iter()) {
            let sample_mean = total / n_samples as f64;
            let expected_mean = a / alpha_sum;
            assert_relative_eq!(sample_mean, expected_mean, epsilon = 0.05);
        }
    }

    /// A single draw has the right dimension and lies on the simplex.
    #[test]
    fn test_dirichlet_rvs_single() {
        let alpha = array![1.0, 2.0, 3.0];
        let dirichlet = Dirichlet::new(alpha.clone()).expect("Operation failed");

        let sample = dirichlet.rvs_single().expect("Operation failed");

        // Check sample dimension
        assert_eq!(sample.len(), 3);

        // Check sample sums to 1
        let sum: f64 = sample.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-10);

        // Check all values in [0,1]
        assert!(sample.iter().all(|val| (0.0..=1.0).contains(val)));
    }
}
520}