Skip to main content

single_statistics/testing/inference/
discrete.rs

1//! Discrete statistical tests for single-cell data analysis.
2//! 
3//! This module implements tests for categorical and count data,
4//! such as Chi-square and Fisher's exact tests.
5
6use crate::testing::{Alternative, TestResult};
7use single_utilities::traits::FloatOpsTS;
8use statrs::distribution::{ChiSquared, ContinuousCDF, Discrete, DiscreteCDF};
9use crate::testing::utils::SparseMatrixRef;
10use num_traits::{AsPrimitive, Float};
11use rayon::prelude::*;
12
13/// Performs a chi-square test for independence on a 2x2 contingency table
14pub fn chi_square_test<T>(
15    a: T,
16    b: T,
17    c: T,
18    d: T,
19    alternative: Alternative,
20) -> TestResult<T>
21where
22    T: FloatOpsTS,
23{
24    let total = a + b + c + d;
25    if total <= T::zero() {
26        return TestResult::new(T::zero(), T::one());
27    }
28
29    // Calculate expected frequencies
30    let row1 = a + b;
31    let row2 = c + d;
32    let col1 = a + c;
33    let col2 = b + d;
34
35    let expected_a = (row1 * col1) / total;
36    let expected_b = (row1 * col2) / total;
37    let expected_c = (row2 * col1) / total;
38    let expected_d = (row2 * col2) / total;
39
40    // Calculate chi-square statistic
41    let chi_square = (Float::powi(a - expected_a, 2) / expected_a)
42        + (Float::powi(b - expected_b, 2) / expected_b)
43        + (Float::powi(c - expected_c, 2) / expected_c)
44        + (Float::powi(d - expected_d, 2) / expected_d);
45
46    // Calculate p-value using chi-square distribution with 1 degree of freedom
47    let p_value = calculate_chi_square_p_value(chi_square, T::one(), alternative);
48
49    TestResult::new(chi_square, p_value)
50}
51
52fn calculate_chi_square_p_value<T>(chi_square: T, df: T, alternative: Alternative) -> T
53where
54    T: FloatOpsTS,
55{
56    let chi_square_f64 = chi_square.to_f64().unwrap();
57    let df_f64 = df.to_f64().unwrap();
58
59    match ChiSquared::new(df_f64) {
60        Ok(chi_dist) => {
61            let p = match alternative {
62                Alternative::TwoSided => 1.0 - chi_dist.cdf(chi_square_f64), // Chi-square is usually 1-tailed
63                Alternative::Less => chi_dist.cdf(chi_square_f64),
64                Alternative::Greater => 1.0 - chi_dist.cdf(chi_square_f64),
65            };
66            T::from(p).unwrap()
67        }
68        Err(_) => T::one(),
69    }
70}
71
72/// Fisher's Exact Test for 2x2 contingency table.
73/// 
74/// Hypergeometric distribution: 
75/// N: total balls, K: total white balls, n: balls drawn, k: white balls drawn
76/// 
77/// Contingency table:
78///         Group1  Group2
79/// Expr      a       b
80/// NonExpr   c       d
81pub fn fisher_exact_test<T>(
82    a: usize,
83    b: usize,
84    c: usize,
85    d: usize,
86    _alternative: Alternative,
87) -> TestResult<T>
88where
89    T: FloatOpsTS,
90{
91    // Implementation uses statrs Hypergeometric distribution
92    use statrs::distribution::Hypergeometric;
93    
94    let n1 = a + c; // Group 1 size
95    let n2 = b + d; // Group 2 size
96    let total_expr = a + b;
97    let total_cells = n1 + n2;
98
99    if total_cells == 0 {
100        return TestResult::new(T::zero(), T::one());
101    }
102
103    // Hypergeometric(total, success_in_total, draws)
104    // Here: N=total_cells, K=total_expr, n=n1 (draws from group 1)
105    match Hypergeometric::new(total_cells as u64, total_expr as u64, n1 as u64) {
106        Ok(hyper) => {
107            let p_val: f64 = match _alternative {
108                Alternative::Greater => 1.0 - hyper.cdf((a as u64).saturating_sub(1)),
109                Alternative::Less => hyper.cdf(a as u64),
110                Alternative::TwoSided => {
111                    let p_a = hyper.pmf(a as u64);
112                    let mut p_sum = 0.0;
113                    let upper_limit = std::cmp::min(n1, total_expr);
114                    for i in 0..=upper_limit {
115                        let p_i = hyper.pmf(i as u64);
116                        if p_i <= p_a + 1e-12 {
117                            p_sum += p_i;
118                        }
119                    }
120                    p_sum.min(1.0)
121                }
122            };
123            
124            let odds_ratio = if b * c == 0 {
125                if a * d > 0 { f64::INFINITY } else { 0.0 }
126            } else {
127                (a as f64 * d as f64) / (b as f64 * c as f64)
128            };
129            
130            TestResult::new(T::from(odds_ratio).unwrap(), T::from(p_val).unwrap())
131        }
132        Err(_) => TestResult::new(T::zero(), T::one()),
133    }
134}
135
136/// Perform Fisher's Exact Test across all genes in a sparse matrix.
137pub fn fisher_exact_sparse<T, N, I>(
138    matrix: SparseMatrixRef<T, N, I>,
139    group1_indices: &[usize],
140    group2_indices: &[usize],
141    alternative: Alternative,
142) -> anyhow::Result<Vec<TestResult<T>>>
143where
144    T: FloatOpsTS,
145    N: AsPrimitive<usize> + Send + Sync,
146    I: AsPrimitive<usize> + Send + Sync,
147{
148    let n_group1 = group1_indices.len();
149    let n_group2 = group2_indices.len();
150    
151    let mut cell_groups = vec![0u8; matrix.n_cols];
152    for &idx in group1_indices { if idx < cell_groups.len() { cell_groups[idx] = 1; } }
153    for &idx in group2_indices { if idx < cell_groups.len() { cell_groups[idx] = 2; } }
154
155    let results: Vec<_> = (0..matrix.n_rows)
156        .into_par_iter()
157        .map(|row| {
158            let start = matrix.maj_ind[row].as_();
159            let end = matrix.maj_ind[row + 1].as_();
160            
161            let mut a = 0; // Group 1 Expressed
162            let mut b = 0; // Group 2 Expressed
163
164            for i in start..end {
165                let col = matrix.min_ind[i].as_();
166                match cell_groups[col] {
167                    1 => a += 1,
168                    2 => b += 1,
169                    _ => {}
170                }
171            }
172            
173            let c = n_group1 - a;
174            let d = n_group2 - b;
175
176            fisher_exact_test(a, b, c, d, alternative)
177        })
178        .collect();
179
180    Ok(results)
181}