Skip to main content

scirs2_stats/distributions/multivariate/
multinomial.rs

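//! Multinomial distribution functions
//!
//! This module provides the [`Multinomial`] distribution: construction with parameter
//! validation, the probability mass function and its logarithm, random sampling, and the
//! mean vector and covariance matrix.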
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
// NOTE: `WeightedAliasIndex` (rand's `distr::weighted` module) may not be available in the
// current version, so the alias-sampler import below is disabled.
9// use scirs2_core::random::weighted::WeightedAliasIndex;
10use scirs2_core::random::prelude::*;
11use scirs2_core::validation::{check_probabilities, check_probabilities_sum_to_one};
12use scirs2_core::Rng;
13use std::fmt::Debug;
14
/// Compute `n!` as an `f64`.
///
/// The product is accumulated in floating point, so very large results are rounded.
/// `170!` (about 7.3e306) is the largest factorial that fits in a finite `f64`. Larger
/// arguments return `f64::INFINITY` immediately instead of looping up to `n` times only
/// to reach the same result.
fn factorial(n: u64) -> f64 {
    // Largest argument whose factorial does not overflow f64.
    const MAX_FINITE_ARG: u64 = 170;

    if n > MAX_FINITE_ARG {
        return f64::INFINITY;
    }

    // 2·3·…·n multiplied left to right; the empty range (n ≤ 1) yields 1.0.
    (2..=n).map(|i| i as f64).product()
}
28
29/// Compute the multinomial coefficient
30///
31/// (n choose n₁, n₂, ..., nₖ) = n! / (n₁! * n₂! * ... * nₖ!)
32#[allow(dead_code)]
33fn multinomial_coef(n: u64, xs: &[u64]) -> f64 {
34    let mut denominator = 1.0;
35    for &x in xs {
36        denominator *= factorial(x);
37    }
38    factorial(n) / denominator
39}
40
41/// Multinomial distribution structure
42///
43/// The multinomial distribution is a generalization of the binomial distribution.
44/// It models the probability of counts for each side of a k-sided die rolled n times.
45#[derive(Debug, Clone)]
46pub struct Multinomial {
47    /// Number of trials
48    pub n: u64,
49    /// Probability of each outcome (must sum to 1)
50    pub p: Array1<f64>,
51    // Alias sampler for efficient random sampling (temporarily disabled)
52    // alias_sampler: WeightedAliasIndex<f64>,
53}
54
55impl Multinomial {
56    /// Create a new Multinomial distribution with given parameters
57    ///
58    /// # Arguments
59    ///
60    /// * `n` - Number of trials
61    /// * `p` - Probability of each outcome (must sum to 1)
62    ///
63    /// # Returns
64    ///
65    /// * A new Multinomial distribution instance
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use scirs2_core::ndarray::array;
71    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
72    ///
73    /// // Create a multinomial distribution for a 3-sided die rolled 10 times
74    /// let n = 10;
75    /// let p = array![0.2, 0.3, 0.5]; // Probabilities for each outcome
76    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
77    /// ```
78    pub fn new<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Self>
79    where
80        D: Data<Elem = f64>,
81    {
82        let p_owned = p.to_owned();
83
84        // Validate that probabilities are non-negative and sum to 1 using core validation
85        check_probabilities(&p_owned, "Probabilities").map_err(StatsError::from)?;
86        check_probabilities_sum_to_one(&p_owned, "Probabilities", None)
87            .map_err(StatsError::from)?;
88
89        // Create alias sampler for efficient random sampling (temporarily disabled)
90        // let alias_sampler = match WeightedAliasIndex::new(p_owned.iter().cloned().collect()) {
91        //     Ok(sampler) => sampler,
92        //     Err(_) => {
93        //         return Err(StatsError::ComputationError(
94        //             "Failed to create alias sampler for random sampling".to_string(),
95        //         ))
96        //     }
97        // };
98
99        Ok(Multinomial {
100            n,
101            p: p_owned,
102            // alias_sampler,
103        })
104    }
105
106    /// Calculate the probability mass function (PMF) at a given point
107    ///
108    /// # Arguments
109    ///
110    /// * `x` - The point at which to evaluate the PMF (must be a vector of non-negative integers that sum to n)
111    ///
112    /// # Returns
113    ///
114    /// * The value of the PMF at the given point
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use scirs2_core::ndarray::array;
120    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
121    ///
122    /// let n = 10;
123    /// let p = array![0.2, 0.3, 0.5];
124    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
125    ///
126    /// // Calculate PMF at x = [2, 3, 5]
127    /// let x = array![2.0, 3.0, 5.0];
128    /// let pmf_value = multinomial.pmf(&x);
129    /// ```
130    pub fn pmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
131    where
132        D: Data<Elem = f64>,
133    {
134        let x_vec = x.to_owned();
135
136        // Check if x has the right dimension
137        if x_vec.len() != self.p.len() {
138            return 0.0;
139        }
140
141        // Convert x to u64 and check if all values are non-negative integers that sum to n
142        let mut x_u64 = Vec::with_capacity(x_vec.len());
143        let mut sum = 0;
144
145        for &val in x_vec.iter() {
146            // Check if value is a non-negative integer
147            if val < 0.0 || (val - val.floor()).abs() > 1e-10 {
148                return 0.0;
149            }
150
151            let val_u64 = val as u64;
152            x_u64.push(val_u64);
153            sum += val_u64;
154        }
155
156        // Check if values sum to n
157        if sum != self.n {
158            return 0.0;
159        }
160
161        // Calculate the multinomial PMF:
162        // P(X = x) = n! / (x₁! * x₂! * ... * xₖ!) * p₁^x₁ * p₂^x₂ * ... * pₖ^xₖ
163
164        // Multinomial coefficient
165        let coef = multinomial_coef(self.n, &x_u64);
166
167        // Product of p_i^x_i
168        let mut product = 1.0;
169        for (i, &count) in x_u64.iter().enumerate() {
170            product *= self.p[i].powf(count as f64);
171        }
172
173        coef * product
174    }
175
176    /// Calculate the log probability mass function (log PMF) at a given point
177    ///
178    /// # Arguments
179    ///
180    /// * `x` - The point at which to evaluate the log PMF (must be a vector of non-negative integers that sum to n)
181    ///
182    /// # Returns
183    ///
184    /// * The value of the log PMF at the given point
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// use scirs2_core::ndarray::array;
190    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
191    ///
192    /// let n = 10;
193    /// let p = array![0.2, 0.3, 0.5];
194    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
195    ///
196    /// // Calculate log PMF at x = [2, 3, 5]
197    /// let x = array![2.0, 3.0, 5.0];
198    /// let logpmf_value = multinomial.logpmf(&x);
199    /// ```
200    pub fn logpmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
201    where
202        D: Data<Elem = f64>,
203    {
204        let x_vec = x.to_owned();
205
206        // Check if x has the right dimension
207        if x_vec.len() != self.p.len() {
208            return f64::NEG_INFINITY;
209        }
210
211        // Convert x to u64 and check if all values are non-negative integers that sum to n
212        let mut x_u64 = Vec::with_capacity(x_vec.len());
213        let mut sum = 0;
214
215        for &val in x_vec.iter() {
216            // Check if value is a non-negative integer
217            if val < 0.0 || (val - val.floor()).abs() > 1e-10 {
218                return f64::NEG_INFINITY;
219            }
220
221            let val_u64 = val as u64;
222            x_u64.push(val_u64);
223            sum += val_u64;
224        }
225
226        // Check if values sum to n
227        if sum != self.n {
228            return f64::NEG_INFINITY;
229        }
230
231        // Calculate the log multinomial PMF:
232        // log(P(X = x)) = log(n! / (x₁! * x₂! * ... * xₖ!)) + x₁*log(p₁) + x₂*log(p₂) + ... + xₖ*log(pₖ)
233
234        // Log of multinomial coefficient
235        let log_coef = factorial(self.n).ln();
236        let mut log_denom = 0.0;
237        for &count in &x_u64 {
238            log_denom += factorial(count).ln();
239        }
240
241        // Sum of x_i*log(p_i)
242        let mut log_prob_sum = 0.0;
243        for (i, &count) in x_u64.iter().enumerate() {
244            if count > 0 {
245                log_prob_sum += (count as f64) * self.p[i].ln();
246            }
247        }
248
249        log_coef - log_denom + log_prob_sum
250    }
251
252    /// Generate random samples from the distribution
253    ///
254    /// # Arguments
255    ///
256    /// * `size` - Number of samples to generate
257    ///
258    /// # Returns
259    ///
260    /// * Vector of random samples (each sample is a vector of counts)
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use scirs2_core::ndarray::array;
266    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
267    ///
268    /// let n = 10;
269    /// let p = array![0.2, 0.3, 0.5];
270    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
271    ///
272    /// // Generate 5 random samples
273    /// let samples = multinomial.rvs(5).expect("Operation failed");
274    /// assert_eq!(samples.len(), 5);
275    /// assert_eq!(samples[0].len(), 3);
276    /// ```
277    pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
278        let mut rng = thread_rng();
279        let mut samples = Vec::with_capacity(size);
280        let k = self.p.len();
281
282        for _ in 0..size {
283            // Initialize counts to zero
284            let mut counts = vec![0u64; k];
285
286            // Simulate n trials
287            for _ in 0..self.n {
288                // Sample category using cumulative probability
289                let u: f64 = rng.random();
290                let mut cumulative = 0.0;
291                let mut category = 0;
292                for (i, &prob) in self.p.iter().enumerate() {
293                    cumulative += prob;
294                    if u <= cumulative {
295                        category = i;
296                        break;
297                    }
298                }
299                counts[category] += 1;
300            }
301
302            // Convert to floating-point array for consistency with other distributions
303            let sample = Array1::from_iter(counts.iter().map(|&x| x as f64));
304            samples.push(sample);
305        }
306
307        Ok(samples)
308    }
309
310    /// Generate a single random sample from the distribution
311    ///
312    /// # Returns
313    ///
314    /// * A random sample (a vector of counts)
315    ///
316    /// # Examples
317    ///
318    /// ```
319    /// use scirs2_core::ndarray::array;
320    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
321    ///
322    /// let n = 10;
323    /// let p = array![0.2, 0.3, 0.5];
324    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
325    ///
326    /// // Generate a single random sample
327    /// let sample = multinomial.rvs_single().expect("Operation failed");
328    /// assert_eq!(sample.len(), 3);
329    /// ```
330    pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
331        let samples = self.rvs(1)?;
332        Ok(samples[0].clone())
333    }
334
335    /// Calculate the mean of the distribution
336    ///
337    /// # Returns
338    ///
339    /// * Mean vector (n * p)
340    ///
341    /// # Examples
342    ///
343    /// ```
344    /// use scirs2_core::ndarray::array;
345    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
346    ///
347    /// let n = 10;
348    /// let p = array![0.2, 0.3, 0.5];
349    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
350    ///
351    /// let mean = multinomial.mean();
352    /// // Mean should be [2.0, 3.0, 5.0]
353    /// ```
354    pub fn mean(&self) -> Array1<f64> {
355        let n_f64 = self.n as f64;
356        self.p.mapv(|p_i| n_f64 * p_i)
357    }
358
359    /// Calculate the covariance matrix of the distribution
360    ///
361    /// # Returns
362    ///
363    /// * Covariance matrix
364    ///
365    /// # Examples
366    ///
367    /// ```
368    /// use scirs2_core::ndarray::array;
369    /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
370    ///
371    /// let n = 10;
372    /// let p = array![0.2, 0.3, 0.5];
373    /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
374    ///
375    /// let cov = multinomial.cov();
376    /// ```
377    pub fn cov(&self) -> scirs2_core::ndarray::Array2<f64> {
378        let k = self.p.len();
379        let n_f64 = self.n as f64;
380        let mut cov = scirs2_core::ndarray::Array2::zeros((k, k));
381
382        // Fill the covariance matrix
383        // Diagonal: n*p_i*(1-p_i)
384        // Off-diagonal: -n*p_i*p_j
385        for i in 0..k {
386            for j in 0..k {
387                if i == j {
388                    cov[[i, j]] = n_f64 * self.p[i] * (1.0 - self.p[i]);
389                } else {
390                    cov[[i, j]] = -n_f64 * self.p[i] * self.p[j];
391                }
392            }
393        }
394
395        cov
396    }
397}
398
399/// Create a Multinomial distribution with the given parameters.
400///
401/// This is a convenience function to create a Multinomial distribution with
402/// the given number of trials and probability vector.
403///
404/// # Arguments
405///
406/// * `n` - Number of trials
407/// * `p` - Probability of each outcome (must sum to 1)
408///
409/// # Returns
410///
411/// * A Multinomial distribution object
412///
413/// # Examples
414///
415/// ```
416/// use scirs2_core::ndarray::array;
417/// use scirs2_stats::distributions::multivariate;
418///
419/// let n = 10;
420/// let p = array![0.2, 0.3, 0.5]; // Probabilities for each outcome
421/// let multinomial = multivariate::multinomial(n, p).expect("Operation failed");
422/// ```
423#[allow(dead_code)]
424pub fn multinomial<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Multinomial>
425where
426    D: Data<Elem = f64>,
427{
428    Multinomial::new(n, p)
429}
430
431/// Implementation of SampleableDistribution for Multinomial
432impl SampleableDistribution<Array1<f64>> for Multinomial {
433    fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
434        self.rvs(size)
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_multinomial_creation() {
        // Valid multinomial
        let n = 10;
        let p = array![0.2, 0.3, 0.5];
        let multinomial = Multinomial::new(n, p.clone()).expect("Operation failed");
        assert_eq!(multinomial.n, n);
        assert_eq!(multinomial.p, p);

        // Invalid probabilities (don't sum to 1)
        let p_invalid_sum = array![0.2, 0.3, 0.6]; // sum = 1.1
        assert!(Multinomial::new(n, p_invalid_sum).is_err());

        // Invalid probabilities (negative values)
        let p_negative = array![0.2, -0.1, 0.9];
        assert!(Multinomial::new(n, p_negative).is_err());
    }

    #[test]
    fn test_multinomial_pmf() {
        let n = 5;
        let p = array![0.5, 0.5];
        let multinomial = Multinomial::new(n, p).expect("Operation failed");

        // PMF at x = [2, 3]
        let x1 = array![2.0, 3.0];
        let pmf1 = multinomial.pmf(&x1);

        // Calculate expected PMF: 5!/(2!*3!) * 0.5^2 * 0.5^3 = 10 * 0.25 * 0.125 = 0.3125
        let expected_pmf1 = 0.3125;
        assert_relative_eq!(pmf1, expected_pmf1, epsilon = 1e-10);

        // PMF at x = [5, 0]
        let x2 = array![5.0, 0.0];
        let pmf2 = multinomial.pmf(&x2);

        // Calculate expected PMF: 5!/(5!*0!) * 0.5^5 * 0.5^0 = 1 * 0.03125 * 1 = 0.03125
        let expected_pmf2 = 0.03125;
        assert_relative_eq!(pmf2, expected_pmf2, epsilon = 1e-10);

        // PMF at invalid x (doesn't sum to n)
        let x_invalid = array![2.0, 2.0]; // sum = 4 != 5
        let pmf_invalid = multinomial.pmf(&x_invalid);
        assert_eq!(pmf_invalid, 0.0);

        // PMF at invalid x (non-integer values)
        let x_non_int = array![2.5, 2.5];
        let pmf_non_int = multinomial.pmf(&x_non_int);
        assert_eq!(pmf_non_int, 0.0);

        // PMF at invalid x (wrong dimension)
        let x_wrong_dim = array![2.0, 3.0, 0.0];
        let pmf_wrong_dim = multinomial.pmf(&x_wrong_dim);
        assert_eq!(pmf_wrong_dim, 0.0);
    }

    #[test]
    fn test_multinomial_pmf_sums_to_one() {
        // Summing the PMF over the whole support (all count vectors totalling n) gives 1.
        let n = 4;
        let multinomial = Multinomial::new(n, array![0.2, 0.3, 0.5]).expect("Operation failed");

        let mut total = 0.0;
        for a in 0..=n {
            for b in 0..=(n - a) {
                let c = n - a - b;
                total += multinomial.pmf(&array![a as f64, b as f64, c as f64]);
            }
        }
        assert_relative_eq!(total, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_multinomial_zero_probability_category() {
        let multinomial = Multinomial::new(3, array![0.0, 1.0]).expect("Operation failed");

        // All mass on the certain category.
        assert_relative_eq!(multinomial.pmf(&array![0.0, 3.0]), 1.0, epsilon = 1e-12);
        assert_relative_eq!(multinomial.logpmf(&array![0.0, 3.0]), 0.0, epsilon = 1e-12);

        // Any count in the impossible category has zero probability.
        assert_eq!(multinomial.pmf(&array![1.0, 2.0]), 0.0);
        assert_eq!(multinomial.logpmf(&array![1.0, 2.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn test_multinomial_logpmf() {
        let n = 5;
        let p = array![0.5, 0.5];
        let multinomial = Multinomial::new(n, p).expect("Operation failed");

        // LogPMF at x = [2, 3]
        let x1 = array![2.0, 3.0];
        let logpmf1 = multinomial.logpmf(&x1);
        let pmf1 = multinomial.pmf(&x1);

        // Check that exp(logPMF) = PMF
        assert_relative_eq!(logpmf1.exp(), pmf1, epsilon = 1e-10);

        // LogPMF at invalid x (doesn't sum to n)
        let x_invalid = array![2.0, 2.0]; // sum = 4 != 5
        let logpmf_invalid = multinomial.logpmf(&x_invalid);
        assert_eq!(logpmf_invalid, f64::NEG_INFINITY);
    }

    #[test]
    fn test_multinomial_mean() {
        let n = 10;
        let p = array![0.2, 0.3, 0.5];
        let multinomial = Multinomial::new(n, p).expect("Operation failed");

        let mean = multinomial.mean();
        let expected_mean = array![2.0, 3.0, 5.0];

        for i in 0..3 {
            assert_relative_eq!(mean[i], expected_mean[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multinomial_cov() {
        let n = 10;
        let p = array![0.2, 0.3, 0.5];
        let multinomial = Multinomial::new(n, p).expect("Operation failed");

        let cov = multinomial.cov();

        // Expected covariance matrix:
        // [n*p1*(1-p1), -n*p1*p2, -n*p1*p3]
        // [-n*p2*p1, n*p2*(1-p2), -n*p2*p3]
        // [-n*p3*p1, -n*p3*p2, n*p3*(1-p3)]

        // Diagonal elements
        assert_relative_eq!(cov[[0, 0]], 10.0 * 0.2 * 0.8, epsilon = 1e-10); // 1.6
        assert_relative_eq!(cov[[1, 1]], 10.0 * 0.3 * 0.7, epsilon = 1e-10); // 2.1
        assert_relative_eq!(cov[[2, 2]], 10.0 * 0.5 * 0.5, epsilon = 1e-10); // 2.5

        // Off-diagonal elements
        assert_relative_eq!(cov[[0, 1]], -10.0 * 0.2 * 0.3, epsilon = 1e-10); // -0.6
        assert_relative_eq!(cov[[0, 2]], -10.0 * 0.2 * 0.5, epsilon = 1e-10); // -1.0
        assert_relative_eq!(cov[[1, 2]], -10.0 * 0.3 * 0.5, epsilon = 1e-10); // -1.5

        // Symmetry
        assert_relative_eq!(cov[[1, 0]], cov[[0, 1]], epsilon = 1e-10);
        assert_relative_eq!(cov[[2, 0]], cov[[0, 2]], epsilon = 1e-10);
        assert_relative_eq!(cov[[2, 1]], cov[[1, 2]], epsilon = 1e-10);
    }

    #[test]
    fn test_multinomial_cov_rows_sum_to_zero() {
        // The counts always total n, so each row of the covariance matrix sums to zero.
        let multinomial = Multinomial::new(10, array![0.2, 0.3, 0.5]).expect("Operation failed");
        let cov = multinomial.cov();
        for i in 0..3 {
            let row_sum: f64 = (0..3).map(|j| cov[[i, j]]).sum();
            assert_relative_eq!(row_sum, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multinomial_rvs() {
        let n = 100;
        let p = array![0.2, 0.3, 0.5];
        let multinomial = Multinomial::new(n, p.clone()).expect("Operation failed");

        // Generate samples
        let num_samples = 100;
        let samples = multinomial.rvs(num_samples).expect("Operation failed");

        // Check the number of samples
        assert_eq!(samples.len(), num_samples);

        // Check the dimension of each sample
        for sample in &samples {
            assert_eq!(sample.len(), 3);

            // Check that each sample sums to n
            let sum: f64 = sample.sum();
            assert_eq!(sum, n as f64);
        }

        // Calculate sample means
        let mut sample_sum = array![0.0, 0.0, 0.0];
        for sample in &samples {
            sample_sum += sample;
        }
        let sample_mean = sample_sum / num_samples as f64;

        // Expected means
        let expected_mean = array![20.0, 30.0, 50.0];

        // Check that sample means are reasonably close to expected means
        // (using larger tolerance due to random sampling)
        for i in 0..3 {
            assert!((sample_mean[i] - expected_mean[i]).abs() < 5.0);
        }
    }

    #[test]
    fn test_multinomial_rvs_respects_zero_probability() {
        let n = 50;
        let multinomial = Multinomial::new(n, array![0.0, 0.4, 0.6]).expect("Operation failed");

        for sample in multinomial.rvs(20).expect("Operation failed") {
            // A zero-probability category must never receive a count.
            assert_eq!(sample[0], 0.0);
            assert_eq!(sample.sum(), n as f64);
        }
    }

    #[test]
    fn test_multinomial_rvs_single() {
        let n = 10;
        let p = array![0.2, 0.3, 0.5];
        let multinomial = Multinomial::new(n, p).expect("Operation failed");

        let sample = multinomial.rvs_single().expect("Operation failed");

        // Check the dimension of the sample
        assert_eq!(sample.len(), 3);

        // Check that the sample sums to n
        let sum: f64 = sample.sum();
        assert_eq!(sum, n as f64);
    }

    #[test]
    fn test_multinomial_coef() {
        // (5 choose 2,3) = 5! / (2! * 3!)
        let coef1 = multinomial_coef(5, &[2, 3]);
        assert_eq!(coef1, 10.0);

        // (8 choose 3,2,3) = 8! / (3! * 2! * 3!)
        let coef2 = multinomial_coef(8, &[3, 2, 3]);
        assert_eq!(coef2, 560.0);
    }
}