scirs2_stats/distributions/multivariate/multinomial.rs
1//! Multinomial distribution functions
2//!
3//! This module provides functionality for the Multinomial distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
// NOTE: rand's `distr::weighted` module may not be available in the current version
9// use scirs2_core::random::weighted::WeightedAliasIndex;
10use scirs2_core::random::prelude::*;
11use scirs2_core::validation::{check_probabilities, check_probabilities_sum_to_one};
12use scirs2_core::Rng;
13use std::fmt::Debug;
14
/// Compute `n!` as an `f64`.
///
/// Returns `1.0` for `n <= 1`. The product is accumulated in floating point in
/// ascending order (`2 * 3 * ... * n`), so it overflows to `+inf` once the true
/// value exceeds the `f64` range.
#[allow(dead_code)]
fn factorial(n: u64) -> f64 {
    // For n < 2 the range is empty and the fold yields the identity, 1.0.
    (2..=n).fold(1.0, |acc, k| acc * k as f64)
}
28
/// Compute the multinomial coefficient
///
/// (n choose n₁, n₂, ..., nₖ) = n! / (n₁! * n₂! * ... * nₖ!)
///
/// The value is built up incrementally instead of dividing two factorials, so it
/// does not turn into `inf / inf = NaN` once `n!` no longer fits in an `f64`
/// (n > 170). Every intermediate value is itself an integer (a product of binomial
/// coefficients), so results are exact as long as they fit in the 53-bit mantissa.
///
/// # Arguments
///
/// * `n` - Total number of trials
/// * `xs` - Counts per category; expected to sum to `n`
///
/// # Returns
///
/// * `n! / (x₁! * ... * xₖ!)` as an `f64`. If the counts do not sum to `n` the same
///   ratio is still returned (it is then not a multinomial coefficient). The result
///   is `+inf` when the true value exceeds the `f64` range.
#[allow(dead_code)]
fn multinomial_coef(n: u64, xs: &[u64]) -> f64 {
    let mut coef = 1.0_f64;
    // Number of trials accounted for so far; after the loops it equals sum(xs).
    let mut placed: u64 = 0;

    for &x in xs {
        for j in 1..=x {
            placed += 1;
            // coef goes from M * C(placed - 1, j - 1) to M * C(placed, j); both are integers.
            coef = coef * (placed as f64) / (j as f64);
        }
    }

    // At this point coef == placed! / prod(x_i!). Rescale by n! / placed! so the
    // documented ratio is returned even when the counts do not add up to n.
    if placed < n {
        for i in (placed + 1)..=n {
            coef *= i as f64;
        }
    } else if placed > n {
        for i in (n + 1)..=placed {
            coef /= i as f64;
        }
    }

    coef
}
40
/// Multinomial distribution structure
///
/// The multinomial distribution is a generalization of the binomial distribution.
/// It models the probability of counts for each side of a k-sided die rolled n times.
///
/// Instances are normally built with [`Multinomial::new`], which validates the
/// probability vector. Both fields are public, so nothing prevents them from being
/// changed afterwards without re-validation.
#[derive(Debug, Clone)]
pub struct Multinomial {
    /// Number of trials
    pub n: u64,
    /// Probability of each outcome (must sum to 1); its length is the number of categories
    pub p: Array1<f64>,
    // Alias sampler for efficient random sampling (temporarily disabled)
    // alias_sampler: WeightedAliasIndex<f64>,
}
54
55impl Multinomial {
56 /// Create a new Multinomial distribution with given parameters
57 ///
58 /// # Arguments
59 ///
60 /// * `n` - Number of trials
61 /// * `p` - Probability of each outcome (must sum to 1)
62 ///
63 /// # Returns
64 ///
65 /// * A new Multinomial distribution instance
66 ///
67 /// # Examples
68 ///
69 /// ```
70 /// use scirs2_core::ndarray::array;
71 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
72 ///
73 /// // Create a multinomial distribution for a 3-sided die rolled 10 times
74 /// let n = 10;
75 /// let p = array![0.2, 0.3, 0.5]; // Probabilities for each outcome
76 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
77 /// ```
78 pub fn new<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Self>
79 where
80 D: Data<Elem = f64>,
81 {
82 let p_owned = p.to_owned();
83
84 // Validate that probabilities are non-negative and sum to 1 using core validation
85 check_probabilities(&p_owned, "Probabilities").map_err(StatsError::from)?;
86 check_probabilities_sum_to_one(&p_owned, "Probabilities", None)
87 .map_err(StatsError::from)?;
88
89 // Create alias sampler for efficient random sampling (temporarily disabled)
90 // let alias_sampler = match WeightedAliasIndex::new(p_owned.iter().cloned().collect()) {
91 // Ok(sampler) => sampler,
92 // Err(_) => {
93 // return Err(StatsError::ComputationError(
94 // "Failed to create alias sampler for random sampling".to_string(),
95 // ))
96 // }
97 // };
98
99 Ok(Multinomial {
100 n,
101 p: p_owned,
102 // alias_sampler,
103 })
104 }
105
106 /// Calculate the probability mass function (PMF) at a given point
107 ///
108 /// # Arguments
109 ///
110 /// * `x` - The point at which to evaluate the PMF (must be a vector of non-negative integers that sum to n)
111 ///
112 /// # Returns
113 ///
114 /// * The value of the PMF at the given point
115 ///
116 /// # Examples
117 ///
118 /// ```
119 /// use scirs2_core::ndarray::array;
120 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
121 ///
122 /// let n = 10;
123 /// let p = array![0.2, 0.3, 0.5];
124 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
125 ///
126 /// // Calculate PMF at x = [2, 3, 5]
127 /// let x = array![2.0, 3.0, 5.0];
128 /// let pmf_value = multinomial.pmf(&x);
129 /// ```
130 pub fn pmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
131 where
132 D: Data<Elem = f64>,
133 {
134 let x_vec = x.to_owned();
135
136 // Check if x has the right dimension
137 if x_vec.len() != self.p.len() {
138 return 0.0;
139 }
140
141 // Convert x to u64 and check if all values are non-negative integers that sum to n
142 let mut x_u64 = Vec::with_capacity(x_vec.len());
143 let mut sum = 0;
144
145 for &val in x_vec.iter() {
146 // Check if value is a non-negative integer
147 if val < 0.0 || (val - val.floor()).abs() > 1e-10 {
148 return 0.0;
149 }
150
151 let val_u64 = val as u64;
152 x_u64.push(val_u64);
153 sum += val_u64;
154 }
155
156 // Check if values sum to n
157 if sum != self.n {
158 return 0.0;
159 }
160
161 // Calculate the multinomial PMF:
162 // P(X = x) = n! / (x₁! * x₂! * ... * xₖ!) * p₁^x₁ * p₂^x₂ * ... * pₖ^xₖ
163
164 // Multinomial coefficient
165 let coef = multinomial_coef(self.n, &x_u64);
166
167 // Product of p_i^x_i
168 let mut product = 1.0;
169 for (i, &count) in x_u64.iter().enumerate() {
170 product *= self.p[i].powf(count as f64);
171 }
172
173 coef * product
174 }
175
176 /// Calculate the log probability mass function (log PMF) at a given point
177 ///
178 /// # Arguments
179 ///
180 /// * `x` - The point at which to evaluate the log PMF (must be a vector of non-negative integers that sum to n)
181 ///
182 /// # Returns
183 ///
184 /// * The value of the log PMF at the given point
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// use scirs2_core::ndarray::array;
190 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
191 ///
192 /// let n = 10;
193 /// let p = array![0.2, 0.3, 0.5];
194 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
195 ///
196 /// // Calculate log PMF at x = [2, 3, 5]
197 /// let x = array![2.0, 3.0, 5.0];
198 /// let logpmf_value = multinomial.logpmf(&x);
199 /// ```
200 pub fn logpmf<D>(&self, x: &ArrayBase<D, Ix1>) -> f64
201 where
202 D: Data<Elem = f64>,
203 {
204 let x_vec = x.to_owned();
205
206 // Check if x has the right dimension
207 if x_vec.len() != self.p.len() {
208 return f64::NEG_INFINITY;
209 }
210
211 // Convert x to u64 and check if all values are non-negative integers that sum to n
212 let mut x_u64 = Vec::with_capacity(x_vec.len());
213 let mut sum = 0;
214
215 for &val in x_vec.iter() {
216 // Check if value is a non-negative integer
217 if val < 0.0 || (val - val.floor()).abs() > 1e-10 {
218 return f64::NEG_INFINITY;
219 }
220
221 let val_u64 = val as u64;
222 x_u64.push(val_u64);
223 sum += val_u64;
224 }
225
226 // Check if values sum to n
227 if sum != self.n {
228 return f64::NEG_INFINITY;
229 }
230
231 // Calculate the log multinomial PMF:
232 // log(P(X = x)) = log(n! / (x₁! * x₂! * ... * xₖ!)) + x₁*log(p₁) + x₂*log(p₂) + ... + xₖ*log(pₖ)
233
234 // Log of multinomial coefficient
235 let log_coef = factorial(self.n).ln();
236 let mut log_denom = 0.0;
237 for &count in &x_u64 {
238 log_denom += factorial(count).ln();
239 }
240
241 // Sum of x_i*log(p_i)
242 let mut log_prob_sum = 0.0;
243 for (i, &count) in x_u64.iter().enumerate() {
244 if count > 0 {
245 log_prob_sum += (count as f64) * self.p[i].ln();
246 }
247 }
248
249 log_coef - log_denom + log_prob_sum
250 }
251
252 /// Generate random samples from the distribution
253 ///
254 /// # Arguments
255 ///
256 /// * `size` - Number of samples to generate
257 ///
258 /// # Returns
259 ///
260 /// * Vector of random samples (each sample is a vector of counts)
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use scirs2_core::ndarray::array;
266 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
267 ///
268 /// let n = 10;
269 /// let p = array![0.2, 0.3, 0.5];
270 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
271 ///
272 /// // Generate 5 random samples
273 /// let samples = multinomial.rvs(5).expect("Operation failed");
274 /// assert_eq!(samples.len(), 5);
275 /// assert_eq!(samples[0].len(), 3);
276 /// ```
277 pub fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
278 let mut rng = thread_rng();
279 let mut samples = Vec::with_capacity(size);
280 let k = self.p.len();
281
282 for _ in 0..size {
283 // Initialize counts to zero
284 let mut counts = vec![0u64; k];
285
286 // Simulate n trials
287 for _ in 0..self.n {
288 // Sample category using cumulative probability
289 let u: f64 = rng.random();
290 let mut cumulative = 0.0;
291 let mut category = 0;
292 for (i, &prob) in self.p.iter().enumerate() {
293 cumulative += prob;
294 if u <= cumulative {
295 category = i;
296 break;
297 }
298 }
299 counts[category] += 1;
300 }
301
302 // Convert to floating-point array for consistency with other distributions
303 let sample = Array1::from_iter(counts.iter().map(|&x| x as f64));
304 samples.push(sample);
305 }
306
307 Ok(samples)
308 }
309
310 /// Generate a single random sample from the distribution
311 ///
312 /// # Returns
313 ///
314 /// * A random sample (a vector of counts)
315 ///
316 /// # Examples
317 ///
318 /// ```
319 /// use scirs2_core::ndarray::array;
320 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
321 ///
322 /// let n = 10;
323 /// let p = array![0.2, 0.3, 0.5];
324 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
325 ///
326 /// // Generate a single random sample
327 /// let sample = multinomial.rvs_single().expect("Operation failed");
328 /// assert_eq!(sample.len(), 3);
329 /// ```
330 pub fn rvs_single(&self) -> StatsResult<Array1<f64>> {
331 let samples = self.rvs(1)?;
332 Ok(samples[0].clone())
333 }
334
335 /// Calculate the mean of the distribution
336 ///
337 /// # Returns
338 ///
339 /// * Mean vector (n * p)
340 ///
341 /// # Examples
342 ///
343 /// ```
344 /// use scirs2_core::ndarray::array;
345 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
346 ///
347 /// let n = 10;
348 /// let p = array![0.2, 0.3, 0.5];
349 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
350 ///
351 /// let mean = multinomial.mean();
352 /// // Mean should be [2.0, 3.0, 5.0]
353 /// ```
354 pub fn mean(&self) -> Array1<f64> {
355 let n_f64 = self.n as f64;
356 self.p.mapv(|p_i| n_f64 * p_i)
357 }
358
359 /// Calculate the covariance matrix of the distribution
360 ///
361 /// # Returns
362 ///
363 /// * Covariance matrix
364 ///
365 /// # Examples
366 ///
367 /// ```
368 /// use scirs2_core::ndarray::array;
369 /// use scirs2_stats::distributions::multivariate::multinomial::Multinomial;
370 ///
371 /// let n = 10;
372 /// let p = array![0.2, 0.3, 0.5];
373 /// let multinomial = Multinomial::new(n, p).expect("Operation failed");
374 ///
375 /// let cov = multinomial.cov();
376 /// ```
377 pub fn cov(&self) -> scirs2_core::ndarray::Array2<f64> {
378 let k = self.p.len();
379 let n_f64 = self.n as f64;
380 let mut cov = scirs2_core::ndarray::Array2::zeros((k, k));
381
382 // Fill the covariance matrix
383 // Diagonal: n*p_i*(1-p_i)
384 // Off-diagonal: -n*p_i*p_j
385 for i in 0..k {
386 for j in 0..k {
387 if i == j {
388 cov[[i, j]] = n_f64 * self.p[i] * (1.0 - self.p[i]);
389 } else {
390 cov[[i, j]] = -n_f64 * self.p[i] * self.p[j];
391 }
392 }
393 }
394
395 cov
396 }
397}
398
/// Create a Multinomial distribution with the given parameters.
///
/// This is a convenience function to create a Multinomial distribution with
/// the given number of trials and probability vector. It forwards its arguments
/// unchanged to [`Multinomial::new`].
///
/// # Arguments
///
/// * `n` - Number of trials
/// * `p` - Probability of each outcome (must sum to 1)
///
/// # Returns
///
/// * A Multinomial distribution object
///
/// # Errors
///
/// Returns whatever error [`Multinomial::new`] reports when `p` fails validation.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_stats::distributions::multivariate;
///
/// let n = 10;
/// let p = array![0.2, 0.3, 0.5]; // Probabilities for each outcome
/// let multinomial = multivariate::multinomial(n, p).expect("Operation failed");
/// ```
#[allow(dead_code)]
pub fn multinomial<D>(n: u64, p: ArrayBase<D, Ix1>) -> StatsResult<Multinomial>
where
    D: Data<Elem = f64>,
{
    Multinomial::new(n, p)
}
430
431/// Implementation of SampleableDistribution for Multinomial
432impl SampleableDistribution<Array1<f64>> for Multinomial {
433 fn rvs(&self, size: usize) -> StatsResult<Vec<Array1<f64>>> {
434 self.rvs(size)
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Construction accepts a valid probability vector and rejects invalid ones.
    #[test]
    fn test_multinomial_creation() {
        let trials = 10;
        let probs = array![0.2, 0.3, 0.5];

        let dist = Multinomial::new(trials, probs.clone()).expect("Operation failed");
        assert_eq!(dist.n, trials);
        assert_eq!(dist.p, probs);

        // Probabilities summing to 1.1 must be rejected
        assert!(Multinomial::new(trials, array![0.2, 0.3, 0.6]).is_err());

        // A negative probability must be rejected
        assert!(Multinomial::new(trials, array![0.2, -0.1, 0.9]).is_err());
    }

    /// PMF values for valid outcomes, and zero for every kind of invalid point.
    #[test]
    fn test_multinomial_pmf() {
        let dist = Multinomial::new(5, array![0.5, 0.5]).expect("Operation failed");

        // 5!/(2!*3!) * 0.5^2 * 0.5^3 = 10 * 0.25 * 0.125 = 0.3125
        assert_relative_eq!(dist.pmf(&array![2.0, 3.0]), 0.3125, epsilon = 1e-10);

        // 5!/(5!*0!) * 0.5^5 * 0.5^0 = 1 * 0.03125 * 1 = 0.03125
        assert_relative_eq!(dist.pmf(&array![5.0, 0.0]), 0.03125, epsilon = 1e-10);

        // Counts that sum to 4 instead of 5
        assert_eq!(dist.pmf(&array![2.0, 2.0]), 0.0);

        // Non-integer counts
        assert_eq!(dist.pmf(&array![2.5, 2.5]), 0.0);

        // Wrong number of categories
        assert_eq!(dist.pmf(&array![2.0, 3.0, 0.0]), 0.0);
    }

    /// The log PMF agrees with the PMF and is negative infinity for invalid points.
    #[test]
    fn test_multinomial_logpmf() {
        let dist = Multinomial::new(5, array![0.5, 0.5]).expect("Operation failed");

        // exp(logPMF) must reproduce the PMF at x = [2, 3]
        let point = array![2.0, 3.0];
        assert_relative_eq!(dist.logpmf(&point).exp(), dist.pmf(&point), epsilon = 1e-10);

        // Counts that sum to 4 instead of 5
        assert_eq!(dist.logpmf(&array![2.0, 2.0]), f64::NEG_INFINITY);
    }

    /// The mean vector equals n * p.
    #[test]
    fn test_multinomial_mean() {
        let dist = Multinomial::new(10, array![0.2, 0.3, 0.5]).expect("Operation failed");

        let mean = dist.mean();
        let expected = [2.0, 3.0, 5.0];

        assert_eq!(mean.len(), expected.len());
        for (got, want) in mean.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    /// The covariance matrix has n*p_i*(1-p_i) on the diagonal, -n*p_i*p_j elsewhere.
    #[test]
    fn test_multinomial_cov() {
        let dist = Multinomial::new(10, array![0.2, 0.3, 0.5]).expect("Operation failed");
        let cov = dist.cov();

        // (row, column) -> expected entry
        let expected: [((usize, usize), f64); 6] = [
            // Diagonal elements
            ((0, 0), 10.0 * 0.2 * 0.8), // 1.6
            ((1, 1), 10.0 * 0.3 * 0.7), // 2.1
            ((2, 2), 10.0 * 0.5 * 0.5), // 2.5
            // Off-diagonal elements
            ((0, 1), -10.0 * 0.2 * 0.3), // -0.6
            ((0, 2), -10.0 * 0.2 * 0.5), // -1.0
            ((1, 2), -10.0 * 0.3 * 0.5), // -1.5
        ];
        for &((row, col), want) in &expected {
            assert_relative_eq!(cov[[row, col]], want, epsilon = 1e-10);
        }

        // Symmetry
        for &(row, col) in &[(1usize, 0usize), (2, 0), (2, 1)] {
            assert_relative_eq!(cov[[row, col]], cov[[col, row]], epsilon = 1e-10);
        }
    }

    /// Samples have the right shape, sum to n, and average close to n * p.
    #[test]
    fn test_multinomial_rvs() {
        let trials = 100;
        let dist = Multinomial::new(trials, array![0.2, 0.3, 0.5]).expect("Operation failed");

        let draws = 100;
        let samples = dist.rvs(draws).expect("Operation failed");
        assert_eq!(samples.len(), draws);

        // Every sample has one count per category and the counts sum to n
        for sample in &samples {
            assert_eq!(sample.len(), 3);
            assert_eq!(sample.sum(), trials as f64);
        }

        // Average the samples component-wise
        let total = samples
            .iter()
            .fold(array![0.0, 0.0, 0.0], |acc, sample| acc + sample);
        let sample_mean = total / draws as f64;

        // Loose tolerance because the samples are random
        let expected_mean = [20.0, 30.0, 50.0];
        assert_eq!(sample_mean.len(), expected_mean.len());
        for (got, want) in sample_mean.iter().zip(expected_mean.iter()) {
            assert!((got - want).abs() < 5.0);
        }
    }

    /// A single draw has one count per category and sums to n.
    #[test]
    fn test_multinomial_rvs_single() {
        let trials = 10;
        let dist = Multinomial::new(trials, array![0.2, 0.3, 0.5]).expect("Operation failed");

        let sample = dist.rvs_single().expect("Operation failed");

        assert_eq!(sample.len(), 3);
        assert_eq!(sample.sum(), trials as f64);
    }

    /// Multinomial coefficients for small hand-checked cases.
    #[test]
    fn test_multinomial_coef() {
        // (5 choose 2,3) = 5! / (2! * 3!)
        assert_eq!(multinomial_coef(5, &[2, 3]), 10.0);

        // (8 choose 3,2,3) = 8! / (3! * 2! * 3!)
        assert_eq!(multinomial_coef(8, &[3, 2, 3]), 560.0);
    }
}