Skip to main content

scirs2_stats/distributions/
bernoulli.rs

//! Bernoulli distribution functions
//!
//! This module provides the Bernoulli distribution: its PMF, log-PMF, CDF,
//! quantile function, random sampling, and summary statistics.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::numeric::{Float, NumCast};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::{Bernoulli as RandBernoulli, Distribution};
10use scirs2_core::validation::check_probability;
11
12/// Bernoulli distribution structure
13///
14/// The Bernoulli distribution is a discrete probability distribution taking
15/// value 1 with probability p and value 0 with probability q = 1 - p.
16/// It is the discrete probability distribution of a random variable which takes
17/// the value 1 with probability p and the value 0 with probability q.
18pub struct Bernoulli<F: Float> {
19    /// Success probability p (0 ≤ p ≤ 1)
20    pub p: F,
21    /// Random number generator
22    rand_distr: RandBernoulli,
23}
24
25impl<F: Float + NumCast + std::fmt::Display> Bernoulli<F> {
26    /// Create a new Bernoulli distribution with given success probability
27    ///
28    /// # Arguments
29    ///
30    /// * `p` - Success probability (0 ≤ p ≤ 1)
31    ///
32    /// # Returns
33    ///
34    /// * A new Bernoulli distribution instance
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
40    ///
41    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
42    /// ```
43    pub fn new(p: F) -> StatsResult<Self> {
44        // Validate parameters using core validation function
45        let _ = check_probability(p, "Success probability").map_err(StatsError::from)?;
46
47        // Create RNG for Bernoulli distribution
48        let p_f64 = <f64 as scirs2_core::numeric::NumCast>::from(p).ok_or_else(|| {
49            StatsError::ComputationError("Failed to convert p to f64".to_string())
50        })?;
51        let rand_distr = match RandBernoulli::new(p_f64) {
52            Ok(distr) => distr,
53            Err(_) => {
54                return Err(StatsError::ComputationError(
55                    "Failed to create Bernoulli distribution for sampling".to_string(),
56                ))
57            }
58        };
59
60        Ok(Bernoulli { p, rand_distr })
61    }
62
63    /// Calculate the probability mass function (PMF) at a given point
64    ///
65    /// # Arguments
66    ///
67    /// * `k` - The point at which to evaluate the PMF (0 or 1)
68    ///
69    /// # Returns
70    ///
71    /// * The value of the PMF at the given point
72    ///
73    /// # Examples
74    ///
75    /// ```
76    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
77    ///
78    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
79    /// let pmf_at_one = bern.pmf(1.0);
80    /// assert!((pmf_at_one - 0.3).abs() < 1e-7);
81    /// ```
82    pub fn pmf(&self, k: F) -> F {
83        let one = F::one();
84        let zero = F::zero();
85
86        // PMF is only defined for k = 0 and k = 1
87        if k == zero {
88            one - self.p // q = 1 - p
89        } else if k == one {
90            self.p
91        } else {
92            zero
93        }
94    }
95
96    /// Calculate the log of the probability mass function (log-PMF) at a given point
97    ///
98    /// # Arguments
99    ///
100    /// * `k` - The point at which to evaluate the log-PMF (0 or 1)
101    ///
102    /// # Returns
103    ///
104    /// * The value of the log-PMF at the given point
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
110    ///
111    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
112    /// let log_pmf_at_one = bern.log_pmf(1.0);
113    /// assert!((log_pmf_at_one - (-1.2039728)).abs() < 1e-6);
114    /// ```
115    pub fn log_pmf(&self, k: F) -> F {
116        let one = F::one();
117        let zero = F::zero();
118        let neg_infinity = F::neg_infinity();
119
120        // log-PMF is only defined for k = 0 and k = 1
121        if k == zero {
122            if self.p == one {
123                neg_infinity
124            } else {
125                (one - self.p).ln() // ln(q) = ln(1 - p)
126            }
127        } else if k == one {
128            if self.p == zero {
129                neg_infinity
130            } else {
131                self.p.ln() // ln(p)
132            }
133        } else {
134            neg_infinity
135        }
136    }
137
138    /// Calculate the cumulative distribution function (CDF) at a given point
139    ///
140    /// # Arguments
141    ///
142    /// * `k` - The point at which to evaluate the CDF
143    ///
144    /// # Returns
145    ///
146    /// * The value of the CDF at the given point
147    ///
148    /// # Examples
149    ///
150    /// ```
151    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
152    ///
153    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
154    /// let cdf_at_zero = bern.cdf(0.0);
155    /// assert!((cdf_at_zero - 0.7).abs() < 1e-7);
156    /// ```
157    pub fn cdf(&self, k: F) -> F {
158        let zero = F::zero();
159        let one = F::one();
160
161        if k < zero {
162            zero
163        } else if k < one {
164            one - self.p // F(0) = P(X ≤ 0) = P(X = 0) = 1 - p
165        } else {
166            one // F(k) = P(X ≤ k) = 1 for k ≥ 1
167        }
168    }
169
170    /// Inverse of the cumulative distribution function (quantile function)
171    ///
172    /// # Arguments
173    ///
174    /// * `p` - Probability value (between 0 and 1)
175    ///
176    /// # Returns
177    ///
178    /// * The value k such that CDF(k) = p
179    ///
180    /// # Examples
181    ///
182    /// ```
183    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
184    ///
185    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
186    /// let quant = bern.ppf(0.8).expect("Operation failed");
187    /// assert_eq!(quant, 1.0);
188    /// ```
189    pub fn ppf(&self, p_val: F) -> StatsResult<F> {
190        // Validate probability using core validation function
191        let p_val = check_probability(p_val, "Probability value").map_err(StatsError::from)?;
192
193        let zero = F::zero();
194        let one = F::one();
195
196        // Quantile function for Bernoulli
197        let q = one - self.p; // q = 1 - p
198
199        if p_val <= q {
200            Ok(zero) // Q(p) = 0 for p ≤ q
201        } else {
202            Ok(one) // Q(p) = 1 for p > q
203        }
204    }
205
206    /// Generate random samples from the distribution
207    ///
208    /// # Arguments
209    ///
210    /// * `size` - Number of samples to generate
211    ///
212    /// # Returns
213    ///
214    /// * Vector of random samples
215    ///
216    /// # Examples
217    ///
218    /// ```
219    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
220    ///
221    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
222    /// let samples = bern.rvs(10).expect("Operation failed");
223    /// assert_eq!(samples.len(), 10);
224    /// ```
225    pub fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
226        let mut rng = thread_rng();
227        let mut samples = Vec::with_capacity(size);
228        let zero = F::zero();
229        let one = F::one();
230
231        for _ in 0..size {
232            // Generate random Bernoulli sample (0 or 1)
233            let sample = if self.rand_distr.sample(&mut rng) {
234                one
235            } else {
236                zero
237            };
238
239            samples.push(sample);
240        }
241
242        Ok(samples)
243    }
244
245    /// Calculate the mean of the distribution
246    ///
247    /// # Returns
248    ///
249    /// * The mean of the distribution
250    ///
251    /// # Examples
252    ///
253    /// ```
254    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
255    ///
256    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
257    /// let mean = bern.mean();
258    /// assert!((mean - 0.3).abs() < 1e-7);
259    /// ```
260    pub fn mean(&self) -> F {
261        // Mean = p
262        self.p
263    }
264
265    /// Calculate the variance of the distribution
266    ///
267    /// # Returns
268    ///
269    /// * The variance of the distribution
270    ///
271    /// # Examples
272    ///
273    /// ```
274    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
275    ///
276    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
277    /// let variance = bern.var();
278    /// assert!((variance - 0.21).abs() < 1e-7);
279    /// ```
280    pub fn var(&self) -> F {
281        // Variance = p * (1 - p)
282        let one = F::one();
283        self.p * (one - self.p)
284    }
285
286    /// Calculate the standard deviation of the distribution
287    ///
288    /// # Returns
289    ///
290    /// * The standard deviation of the distribution
291    ///
292    /// # Examples
293    ///
294    /// ```
295    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
296    ///
297    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
298    /// let std_dev = bern.std();
299    /// assert!((std_dev - 0.458257).abs() < 1e-6);
300    /// ```
301    pub fn std(&self) -> F {
302        // Std = sqrt(variance)
303        self.var().sqrt()
304    }
305
306    /// Calculate the skewness of the distribution
307    ///
308    /// # Returns
309    ///
310    /// * The skewness of the distribution
311    ///
312    /// # Examples
313    ///
314    /// ```
315    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
316    ///
317    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
318    /// let skewness = bern.skewness();
319    /// assert!((skewness - 0.87287156).abs() < 1e-5);
320    /// ```
321    pub fn skewness(&self) -> F {
322        // Skewness = (1 - 2p) / sqrt(p * (1 - p))
323        let one = F::from(1.0).unwrap_or_else(|| F::zero());
324        let two = F::from(2.0).unwrap_or_else(|| F::zero());
325
326        let q = one - self.p; // q = 1 - p
327
328        // Handle special cases to avoid division by zero
329        if self.p == F::zero() || self.p == F::one() {
330            return F::zero(); // Degenerate case, skewness is not well-defined
331        }
332
333        (one - two * self.p) / (self.p * q).sqrt()
334    }
335
336    /// Calculate the kurtosis of the distribution
337    ///
338    /// # Returns
339    ///
340    /// * The excess kurtosis of the distribution
341    ///
342    /// # Examples
343    ///
344    /// ```
345    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
346    ///
347    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
348    /// let kurtosis = bern.kurtosis();
349    /// assert!((kurtosis - (-1.2351)) < 1e-4);
350    /// ```
351    pub fn kurtosis(&self) -> F {
352        // Excess Kurtosis = (1 - 6p(1-p)) / (p(1-p))
353        let one = F::from(1.0).unwrap_or_else(|| F::zero());
354        let six = F::from(6.0).unwrap_or_else(|| F::zero());
355
356        let q = one - self.p; // q = 1 - p
357        let pq = self.p * q;
358
359        // Handle special cases to avoid division by zero
360        if self.p == F::zero() || self.p == F::one() {
361            return F::zero(); // Degenerate case, kurtosis is not well-defined
362        }
363
364        (one - six * pq) / pq
365    }
366
367    /// Calculate the entropy of the distribution
368    ///
369    /// # Returns
370    ///
371    /// * The entropy value
372    ///
373    /// # Examples
374    ///
375    /// ```
376    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
377    ///
378    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
379    /// let entropy = bern.entropy();
380    /// assert!((entropy - 0.6108643).abs() < 1e-6);
381    /// ```
382    pub fn entropy(&self) -> F {
383        // Entropy = -p * ln(p) - (1-p) * ln(1-p)
384        let zero = F::zero();
385        let one = F::one();
386
387        // Handle special cases
388        if self.p == zero || self.p == one {
389            return zero; // Degenerate case, entropy is 0
390        }
391
392        let q = one - self.p; // q = 1 - p
393
394        // H(X) = -p * ln(p) - q * ln(q)
395        -(self.p * self.p.ln() + q * q.ln())
396    }
397
398    /// Calculate the median of the distribution
399    ///
400    /// # Returns
401    ///
402    /// * The median of the distribution
403    ///
404    /// # Examples
405    ///
406    /// ```
407    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
408    ///
409    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
410    /// let median = bern.median();
411    /// assert_eq!(median, 0.0);
412    /// ```
413    pub fn median(&self) -> F {
414        let zero = F::zero();
415        let one = F::one();
416        let half = F::from(0.5).expect("Failed to convert constant to float");
417
418        // Median is 0 if p < 0.5, 0 or 1 if p = 0.5, and 1 if p > 0.5
419        if self.p < half {
420            zero
421        } else if self.p > half {
422            one
423        } else {
424            // When p = 0.5, both 0 and 1 are medians
425            // We return 0 by convention
426            zero
427        }
428    }
429
430    /// Calculate the mode of the distribution
431    ///
432    /// # Returns
433    ///
434    /// * The mode of the distribution
435    ///
436    /// # Examples
437    ///
438    /// ```
439    /// use scirs2_stats::distributions::bernoulli::Bernoulli;
440    ///
441    /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
442    /// let mode = bern.mode();
443    /// assert_eq!(mode, 0.0);
444    /// ```
445    pub fn mode(&self) -> F {
446        let zero = F::zero();
447        let one = F::one();
448        let half = F::from(0.5).expect("Failed to convert constant to float");
449
450        // Mode is 0 if p < 0.5, 0 or 1 if p = 0.5, and 1 if p > 0.5
451        if self.p < half {
452            zero
453        } else if self.p > half {
454            one
455        } else {
456            // When p = 0.5, both 0 and 1 are modes
457            // We return 0 by convention
458            zero
459        }
460    }
461}
462
463/// Create a Bernoulli distribution with the given parameter.
464///
465/// This is a convenience function to create a Bernoulli distribution with
466/// the given success probability.
467///
468/// # Arguments
469///
470/// * `p` - Success probability (0 ≤ p ≤ 1)
471///
472/// # Returns
473///
474/// * A Bernoulli distribution object
475///
476/// # Examples
477///
478/// ```
479/// use scirs2_stats::distributions::bernoulli;
480///
481/// let b = bernoulli::bernoulli(0.3f64).expect("Operation failed");
482/// let pmf_at_one = b.pmf(1.0);
483/// assert!((pmf_at_one - 0.3).abs() < 1e-7);
484/// ```
485#[allow(dead_code)]
486pub fn bernoulli<F>(p: F) -> StatsResult<Bernoulli<F>>
487where
488    F: Float + NumCast + std::fmt::Display,
489{
490    Bernoulli::new(p)
491}
492
493/// Implementation of SampleableDistribution for Bernoulli
494impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Bernoulli<F> {
495    fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
496        self.rvs(size)
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Build an `f64` Bernoulli distribution, panicking if `p` is rejected.
    fn make(p: f64) -> Bernoulli<f64> {
        Bernoulli::new(p).expect("Operation failed")
    }

    /// True when `value` is negative infinity.
    fn is_neg_inf(value: f64) -> bool {
        value.is_infinite() && value.is_sign_negative()
    }

    #[test]
    fn test_bernoulli_creation() {
        // Every probability in [0, 1] is accepted and stored unchanged
        for &p in &[0.0_f64, 0.5, 1.0] {
            assert_eq!(make(p).p, p);
        }

        // Probabilities outside [0, 1] are rejected
        for &bad in &[-0.1_f64, 1.1] {
            assert!(Bernoulli::<f64>::new(bad).is_err());
        }
    }

    #[test]
    fn test_bernoulli_pmf() {
        let dist = make(0.3);

        // Mass on the two support points
        assert_relative_eq!(dist.pmf(0.0), 0.7, epsilon = 1e-10);
        assert_relative_eq!(dist.pmf(1.0), 0.3, epsilon = 1e-10);

        // No mass anywhere else
        assert_eq!(dist.pmf(0.5), 0.0);

        // Degenerate distribution at 0
        let always_zero = make(0.0);
        assert_eq!(always_zero.pmf(0.0), 1.0);
        assert_eq!(always_zero.pmf(1.0), 0.0);

        // Degenerate distribution at 1
        let always_one = make(1.0);
        assert_eq!(always_one.pmf(0.0), 0.0);
        assert_eq!(always_one.pmf(1.0), 1.0);
    }

    #[test]
    fn test_bernoulli_log_pmf() {
        let dist = make(0.3);

        // Log-mass on the two support points
        assert_relative_eq!(dist.log_pmf(0.0), 0.7_f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(dist.log_pmf(1.0), 0.3_f64.ln(), epsilon = 1e-10);

        // Off the support the log-PMF is -infinity
        assert!(is_neg_inf(dist.log_pmf(0.5)));

        // Degenerate distribution at 0
        let always_zero = make(0.0);
        assert_eq!(always_zero.log_pmf(0.0), 0.0);
        assert!(is_neg_inf(always_zero.log_pmf(1.0)));

        // Degenerate distribution at 1
        let always_one = make(1.0);
        assert!(is_neg_inf(always_one.log_pmf(0.0)));
        assert_eq!(always_one.log_pmf(1.0), 0.0);
    }

    #[test]
    fn test_bernoulli_cdf() {
        let dist = make(0.3);

        // (k, expected F(k)) pairs
        let cases: [(f64, f64); 5] = [
            (-0.1, 0.0), // below the support
            (0.0, 0.7),  // F(0) = P(X = 0) = 1 - p
            (0.5, 0.7),  // still only X = 0 is covered
            (1.0, 1.0),  // whole support covered
            (2.0, 1.0),  // above the support
        ];

        for &(k, expected) in cases.iter() {
            assert_eq!(dist.cdf(k), expected);
        }
    }

    #[test]
    fn test_bernoulli_ppf() {
        let dist = make(0.3);

        // (probability, expected quantile) pairs; q = 1 - p = 0.7
        let cases: [(f64, f64); 5] = [
            (0.0, 0.0),
            (0.3, 0.0),  // 0.3 ≤ q
            (0.7, 0.0),  // 0.7 = q
            (0.71, 1.0), // 0.71 > q
            (1.0, 1.0),
        ];

        for &(prob, expected) in cases.iter() {
            assert_eq!(dist.ppf(prob).expect("Operation failed"), expected);
        }

        // Probabilities outside [0, 1] are rejected
        assert!(dist.ppf(-0.1).is_err());
        assert!(dist.ppf(1.1).is_err());
    }

    #[test]
    fn test_bernoulli_rvs() {
        let dist = make(0.5);

        let draws = dist.rvs(100).expect("Operation failed");

        // Exactly the requested number of samples
        assert_eq!(draws.len(), 100);

        // Every sample is one of the two support points
        for &draw in draws.iter() {
            assert!(draw == 0.0 || draw == 1.0);
        }

        // With p = 0.5 the sample mean should land roughly around 0.5;
        // the band is wide to allow for randomness
        let total: f64 = draws.iter().sum();
        let sample_mean = total / draws.len() as f64;
        assert!(sample_mean > 0.3 && sample_mean < 0.7);
    }

    #[test]
    fn test_bernoulli_stats() {
        // p = 0.3
        let dist = make(0.3);

        // Mean = p
        assert_eq!(dist.mean(), 0.3);

        // Variance = p * (1 - p) = 0.21
        assert_relative_eq!(dist.var(), 0.21, epsilon = 1e-10);

        // Standard deviation = sqrt(0.21) ≈ 0.458258
        assert_relative_eq!(dist.std(), 0.21_f64.sqrt(), epsilon = 1e-10);

        // Skewness = (1 - 2p) / sqrt(p * (1 - p)) ≈ 0.872872
        let skew_expected = (1.0_f64 - 2.0 * 0.3) / (0.3_f64 * 0.7).sqrt();
        assert_relative_eq!(dist.skewness(), skew_expected, epsilon = 1e-5);

        // Excess kurtosis = (1 - 6p(1-p)) / (p(1-p)) ≈ -1.238
        let kurt_expected = (1.0_f64 - 6.0 * 0.3 * 0.7) / (0.3 * 0.7);
        assert_relative_eq!(dist.kurtosis(), kurt_expected, epsilon = 1e-6);

        // Entropy = -p * ln(p) - (1-p) * ln(1-p) ≈ 0.610864
        let entropy_expected = -0.3_f64 * 0.3_f64.ln() - 0.7_f64 * 0.7_f64.ln();
        assert_relative_eq!(dist.entropy(), entropy_expected, epsilon = 1e-6);

        // For p < 0.5 median and mode are both 0
        assert_eq!(dist.median(), 0.0);
        assert_eq!(dist.mode(), 0.0);

        // For p > 0.5 median and mode are both 1
        let mostly_one = make(0.8);
        assert_eq!(mostly_one.median(), 1.0);
        assert_eq!(mostly_one.mode(), 1.0);

        // For p = 0.5 either value qualifies; 0 is returned by convention
        let fair = make(0.5);
        assert_eq!(fair.median(), 0.0);
        assert_eq!(fair.mode(), 0.0);
    }
}