single_statistics/testing/inference/
discrete.rs1use crate::testing::{Alternative, TestResult};
7use single_utilities::traits::FloatOpsTS;
8use statrs::distribution::{ChiSquared, ContinuousCDF, Discrete, DiscreteCDF};
9use crate::testing::utils::SparseMatrixRef;
10use num_traits::{AsPrimitive, Float};
11use rayon::prelude::*;
12
13pub fn chi_square_test<T>(
15 a: T,
16 b: T,
17 c: T,
18 d: T,
19 alternative: Alternative,
20) -> TestResult<T>
21where
22 T: FloatOpsTS,
23{
24 let total = a + b + c + d;
25 if total <= T::zero() {
26 return TestResult::new(T::zero(), T::one());
27 }
28
29 let row1 = a + b;
31 let row2 = c + d;
32 let col1 = a + c;
33 let col2 = b + d;
34
35 let expected_a = (row1 * col1) / total;
36 let expected_b = (row1 * col2) / total;
37 let expected_c = (row2 * col1) / total;
38 let expected_d = (row2 * col2) / total;
39
40 let chi_square = (Float::powi(a - expected_a, 2) / expected_a)
42 + (Float::powi(b - expected_b, 2) / expected_b)
43 + (Float::powi(c - expected_c, 2) / expected_c)
44 + (Float::powi(d - expected_d, 2) / expected_d);
45
46 let p_value = calculate_chi_square_p_value(chi_square, T::one(), alternative);
48
49 TestResult::new(chi_square, p_value)
50}
51
52fn calculate_chi_square_p_value<T>(chi_square: T, df: T, alternative: Alternative) -> T
53where
54 T: FloatOpsTS,
55{
56 let chi_square_f64 = chi_square.to_f64().unwrap();
57 let df_f64 = df.to_f64().unwrap();
58
59 match ChiSquared::new(df_f64) {
60 Ok(chi_dist) => {
61 let p = match alternative {
62 Alternative::TwoSided => 1.0 - chi_dist.cdf(chi_square_f64), Alternative::Less => chi_dist.cdf(chi_square_f64),
64 Alternative::Greater => 1.0 - chi_dist.cdf(chi_square_f64),
65 };
66 T::from(p).unwrap()
67 }
68 Err(_) => T::one(),
69 }
70}
71
72pub fn fisher_exact_test<T>(
82 a: usize,
83 b: usize,
84 c: usize,
85 d: usize,
86 _alternative: Alternative,
87) -> TestResult<T>
88where
89 T: FloatOpsTS,
90{
91 use statrs::distribution::Hypergeometric;
93
94 let n1 = a + c; let n2 = b + d; let total_expr = a + b;
97 let total_cells = n1 + n2;
98
99 if total_cells == 0 {
100 return TestResult::new(T::zero(), T::one());
101 }
102
103 match Hypergeometric::new(total_cells as u64, total_expr as u64, n1 as u64) {
106 Ok(hyper) => {
107 let p_val: f64 = match _alternative {
108 Alternative::Greater => 1.0 - hyper.cdf((a as u64).saturating_sub(1)),
109 Alternative::Less => hyper.cdf(a as u64),
110 Alternative::TwoSided => {
111 let p_a = hyper.pmf(a as u64);
112 let mut p_sum = 0.0;
113 let upper_limit = std::cmp::min(n1, total_expr);
114 for i in 0..=upper_limit {
115 let p_i = hyper.pmf(i as u64);
116 if p_i <= p_a + 1e-12 {
117 p_sum += p_i;
118 }
119 }
120 p_sum.min(1.0)
121 }
122 };
123
124 let odds_ratio = if b * c == 0 {
125 if a * d > 0 { f64::INFINITY } else { 0.0 }
126 } else {
127 (a as f64 * d as f64) / (b as f64 * c as f64)
128 };
129
130 TestResult::new(T::from(odds_ratio).unwrap(), T::from(p_val).unwrap())
131 }
132 Err(_) => TestResult::new(T::zero(), T::one()),
133 }
134}
135
136pub fn fisher_exact_sparse<T, N, I>(
138 matrix: SparseMatrixRef<T, N, I>,
139 group1_indices: &[usize],
140 group2_indices: &[usize],
141 alternative: Alternative,
142) -> anyhow::Result<Vec<TestResult<T>>>
143where
144 T: FloatOpsTS,
145 N: AsPrimitive<usize> + Send + Sync,
146 I: AsPrimitive<usize> + Send + Sync,
147{
148 let n_group1 = group1_indices.len();
149 let n_group2 = group2_indices.len();
150
151 let mut cell_groups = vec![0u8; matrix.n_cols];
152 for &idx in group1_indices { if idx < cell_groups.len() { cell_groups[idx] = 1; } }
153 for &idx in group2_indices { if idx < cell_groups.len() { cell_groups[idx] = 2; } }
154
155 let results: Vec<_> = (0..matrix.n_rows)
156 .into_par_iter()
157 .map(|row| {
158 let start = matrix.maj_ind[row].as_();
159 let end = matrix.maj_ind[row + 1].as_();
160
161 let mut a = 0; let mut b = 0; for i in start..end {
165 let col = matrix.min_ind[i].as_();
166 match cell_groups[col] {
167 1 => a += 1,
168 2 => b += 1,
169 _ => {}
170 }
171 }
172
173 let c = n_group1 - a;
174 let d = n_group2 - b;
175
176 fisher_exact_test(a, b, c, d, alternative)
177 })
178 .collect();
179
180 Ok(results)
181}