sfs_core/
utils.rs

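//! Hypergeometric distribution and supporting numerical helpers (harmonic sums,
//! binomial coefficients, log-factorials and log-gamma).
//!
//! Much of the code here is adapted from the implementation in statrs.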
4
5use factorial::ln_factorial;
6
/// Returns the sum of the first n - 1 terms of the harmonic series,
/// i.e. `1 + 1/2 + ... + 1/(n - 1)`.
///
/// This is the `p = 1` case of the p-harmonic sum; it is zero when `n <= 1`.
pub fn harmonic(n: u64) -> f64 {
    (1..n).map(|i| (i as f64).recip()).sum()
}
11
/// Returns the sum of the first n - 1 terms of the p-harmonic series, i.e. the sum of
/// `1 / i^p` for `i` in `1..n`.
///
/// The sum is zero when `n <= 1`.
pub fn p_harmonic(n: u64, p: u32) -> f64 {
    (1..n)
        .map(|i| match i.checked_pow(p) {
            Some(power) => 1.0 / power as f64,
            // `i^p` does not fit in a `u64` (plain `pow` would panic in debug builds and
            // silently wrap in release builds); compute the tiny term in floating point.
            None => (i as f64).powf(f64::from(p)).recip(),
        })
        .sum()
}
16
17/// Returns the PMF of the hypergeometric distribution.
18pub fn hypergeometric_pmf(size: u64, successes: u64, draws: u64, observed: u64) -> f64 {
19    if observed > draws {
20        0.0
21    } else {
22        binomial(successes, observed) * binomial(size - successes, draws - observed)
23            / binomial(size, draws)
24    }
25}
26
27/// Returns the binomial coefficient.
28pub fn binomial(n: u64, k: u64) -> f64 {
29    if k > n {
30        0.0
31    } else {
32        (0.5 + (ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k)).exp()).floor()
33    }
34}
35
36mod factorial {
37    use std::sync::OnceLock;
38
39    use super::gamma::ln_gamma;
40
41    const MAX: usize = 170;
42    const PRECOMPUTED_LEN: usize = MAX + 1;
43
44    fn precomputed() -> &'static [f64; PRECOMPUTED_LEN] {
45        static PRECOMPUTED: OnceLock<[f64; PRECOMPUTED_LEN]> = OnceLock::new();
46
47        PRECOMPUTED.get_or_init(|| {
48            let mut precomputed = [1.0; PRECOMPUTED_LEN];
49
50            precomputed
51                .iter_mut()
52                .enumerate()
53                .skip(1)
54                .fold(1.0, |acc, (i, x)| {
55                    let factorial = acc * i as f64;
56                    *x = factorial;
57                    factorial
58                });
59
60            precomputed
61        })
62    }
63
64    pub(super) fn ln_factorial(x: u64) -> f64 {
65        precomputed()
66            .get(x as usize)
67            .map(|factorial| factorial.ln())
68            .unwrap_or_else(|| ln_gamma(x as f64 + 1.0))
69    }
70}
71
mod gamma {
    use std::f64::consts::{E, PI};

    /// `ln(2 * sqrt(e / π))`, the constant factor of the series approximation in log form.
    const LN_2_SQRT_E_OVER_PI: f64 = 0.620_782_237_635_245_2;
    /// `ln(π)`, used by the reflection formula.
    const LN_PI: f64 = 1.144_729_885_849_400_2;
    /// Shift parameter of the series approximation.
    const R: f64 = 10.900511;
    /// Series coefficients; `DK[0]` is the constant term.
    const DK: &[f64] = &[
        2.485_740_891_387_535_5e-5,
        1.051_423_785_817_219_7,
        -3.456_870_972_220_162_5,
        4.512_277_094_668_948,
        -2.982_852_253_235_766_4,
        1.056_397_115_771_267,
        -1.954_287_731_916_458_7e-1,
        1.709_705_434_044_412e-2,
        -5.719_261_174_043_057e-4,
        4.633_994_733_599_057e-6,
        -2.719_949_084_886_077_2e-9,
    ];

    /// Evaluates `DK[0] + Σ_{k ≥ 1} DK[k] / denominator(k)`, adding terms in index order.
    fn lanczos_sum(denominator: impl Fn(f64) -> f64) -> f64 {
        let mut sum = DK[0];
        for (k, coefficient) in DK.iter().enumerate().skip(1) {
            sum += coefficient / denominator(k as f64);
        }
        sum
    }

    /// Returns the natural logarithm of the gamma function at `x`.
    ///
    /// Arguments `x >= 0.5` (and NaN) are evaluated directly with a Lanczos series
    /// approximation. Smaller arguments go through the reflection formula
    /// `Γ(x) Γ(1 - x) = π / sin(πx)`; where `sin(πx)` is negative its logarithm, and hence
    /// the result, is NaN.
    pub(super) fn ln_gamma(x: f64) -> f64 {
        if x < 0.5 {
            let sum = lanczos_sum(|k| k - x);
            let shifted = 0.5 - x;
            return LN_PI
                - (PI * x).sin().ln()
                - sum.ln()
                - LN_2_SQRT_E_OVER_PI
                - shifted * ((shifted + R) / E).ln();
        }

        let sum = lanczos_sum(|k| x + k - 1.0);
        let shifted = x - 0.5;
        sum.ln() + LN_2_SQRT_E_OVER_PI + shifted * ((shifted + R) / E).ln()
    }
}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_harmonic() {
        assert_eq!(harmonic(1), 0.0);
        assert_approx_eq!(harmonic(2), 1.0, epsilon = 1e-12);
        assert_approx_eq!(harmonic(4), 1.0 + 1.0 / 2.0 + 1.0 / 3.0, epsilon = 1e-12);
    }

    #[test]
    fn test_p_harmonic() {
        assert_eq!(p_harmonic(1, 2), 0.0);
        assert_approx_eq!(p_harmonic(4, 1), harmonic(4), epsilon = 1e-12);
        assert_approx_eq!(p_harmonic(4, 2), 1.0 + 1.0 / 4.0 + 1.0 / 9.0, epsilon = 1e-12);
    }

    #[test]
    fn test_binomial() {
        assert_eq!(binomial(5, 2), 10.0);
        assert_eq!(binomial(10, 0), 1.0);
        assert_eq!(binomial(10, 10), 1.0);
        assert_eq!(binomial(20, 10), 184_756.0);
        // k > n has no ways to choose.
        assert_eq!(binomial(2, 5), 0.0);
    }

    #[test]
    fn test_hypergeometric_pmf() {
        assert_approx_eq!(hypergeometric_pmf(10, 7, 8, 4), 0.0, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(10, 7, 8, 5), 0.466667, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(10, 7, 8, 6), 0.466667, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(10, 7, 8, 7), 0.066667, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(10, 7, 8, 8), 0.0, epsilon = 1e-6);

        assert_approx_eq!(hypergeometric_pmf(6, 2, 2, 0), 0.4, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(6, 2, 2, 1), 0.533333, epsilon = 1e-6);
        assert_approx_eq!(hypergeometric_pmf(6, 2, 2, 2), 0.066667, epsilon = 1e-6);
    }

    #[test]
    fn test_hypergeometric_pmf_sums_to_one() {
        let total: f64 = (0..=8).map(|k| hypergeometric_pmf(10, 7, 8, k)).sum();
        assert_approx_eq!(total, 1.0, epsilon = 1e-9);

        let total: f64 = (0..=20).map(|k| hypergeometric_pmf(50, 20, 20, k)).sum();
        assert_approx_eq!(total, 1.0, epsilon = 1e-9);
    }
}