Skip to main content

single_statistics/testing/inference/
mod.rs

1use crate::testing::utils::{extract_unique_groups, get_group_indices, SparseMatrixRef};
2use crate::testing::{
3    Alternative, MultipleTestResults, TTestType, TestMethod, TestResult, correction,
4};
5use nalgebra_sparse::CsrMatrix;
6use single_utilities::traits::FloatOpsTS;
7use num_traits::AsPrimitive;
8
9pub mod discrete;
10pub mod parametric;
11pub mod nonparametric;
12
13/// Statistical testing methods for sparse matrices, particularly suited for single-cell data.
14///
15/// This trait extends sparse matrix types (like `CsrMatrix` or `SparseMatrixRef`) with 
16/// statistical testing capabilities.
17pub trait MatrixStatTests<T>
18where
19    T: FloatOpsTS,
20{
21    /// Perform t-tests comparing two groups of cells for all genes.
22    fn t_test(
23        &self,
24        group1_indices: &[usize],
25        group2_indices: &[usize],
26        test_type: TTestType,
27    ) -> anyhow::Result<Vec<TestResult<f64>>>;
28
29    /// Perform Mann-Whitney U tests comparing two groups of cells for all genes.
30    fn mann_whitney_test(
31        &self,
32        group1_indices: &[usize],
33        group2_indices: &[usize],
34        alternative: Alternative,
35    ) -> anyhow::Result<Vec<TestResult<f64>>>;
36
37    /// Perform Fisher's Exact tests comparing expression frequency between two groups.
38    fn fisher_exact_test(
39        &self,
40        group1_indices: &[usize],
41        group2_indices: &[usize],
42        alternative: Alternative,
43    ) -> anyhow::Result<Vec<TestResult<T>>>;
44
45    /// Comprehensive differential expression analysis with multiple testing correction.
46    fn differential_expression(
47        &self,
48        group_ids: &[usize],
49        test_method: TestMethod,
50    ) -> anyhow::Result<MultipleTestResults<f64>>;
51}
52
53impl<T> MatrixStatTests<T> for CsrMatrix<T>
54where
55    T: FloatOpsTS,
56    f64: std::convert::From<T>,
57{
58    fn t_test(
59        &self,
60        group1_indices: &[usize],
61        group2_indices: &[usize],
62        test_type: TTestType,
63    ) -> anyhow::Result<Vec<TestResult<f64>>> {
64        parametric::t_test_matrix_groups(self, group1_indices, group2_indices, test_type)
65    }
66
67    fn mann_whitney_test(
68        &self,
69        group1_indices: &[usize],
70        group2_indices: &[usize],
71        alternative: Alternative,
72    ) -> anyhow::Result<Vec<TestResult<f64>>> {
73        nonparametric::mann_whitney_matrix_groups(self, group1_indices, group2_indices, alternative)
74    }
75
76    fn fisher_exact_test(
77        &self,
78        group1_indices: &[usize],
79        group2_indices: &[usize],
80        alternative: Alternative,
81    ) -> anyhow::Result<Vec<TestResult<T>>> {
82        let smr = SparseMatrixRef {
83            maj_ind: self.row_offsets(),
84            min_ind: self.col_indices(),
85            val: self.values(),
86            n_rows: self.nrows(),
87            n_cols: self.ncols(),
88        };
89        discrete::fisher_exact_sparse(smr, group1_indices, group2_indices, alternative)
90    }
91
92    fn differential_expression(
93        &self,
94        group_ids: &[usize],
95        test_method: TestMethod,
96    ) -> anyhow::Result<MultipleTestResults<f64>> {
97        let smr = SparseMatrixRef {
98            maj_ind: self.row_offsets(),
99            min_ind: self.col_indices(),
100            val: self.values(),
101            n_rows: self.nrows(),
102            n_cols: self.ncols(),
103        };
104        smr.differential_expression(group_ids, test_method)
105    }
106}
107
108impl<'a, T, N, I> MatrixStatTests<T> for SparseMatrixRef<'a, T, N, I>
109where
110    T: FloatOpsTS,
111    N: AsPrimitive<usize> + Send + Sync,
112    I: AsPrimitive<usize> + Send + Sync,
113    f64: std::convert::From<T>,
114{
115    fn t_test(
116        &self,
117        group1_indices: &[usize],
118        group2_indices: &[usize],
119        test_type: TTestType,
120    ) -> anyhow::Result<Vec<TestResult<f64>>> {
121        parametric::t_test_sparse(*self, group1_indices, group2_indices, test_type)
122    }
123
124    fn mann_whitney_test(
125        &self,
126        group1_indices: &[usize],
127        group2_indices: &[usize],
128        alternative: Alternative,
129    ) -> anyhow::Result<Vec<TestResult<f64>>> {
130        nonparametric::mann_whitney_sparse(*self, group1_indices, group2_indices, alternative)
131    }
132
133    fn fisher_exact_test(
134        &self,
135        group1_indices: &[usize],
136        group2_indices: &[usize],
137        alternative: Alternative,
138    ) -> anyhow::Result<Vec<TestResult<T>>> {
139        discrete::fisher_exact_sparse(*self, group1_indices, group2_indices, alternative)
140    }
141
142    fn differential_expression(
143        &self,
144        group_ids: &[usize],
145        test_method: TestMethod,
146    ) -> anyhow::Result<MultipleTestResults<f64>> {
147        let unique_groups = extract_unique_groups(group_ids);
148        if unique_groups.len() != 2 {
149            return Err(anyhow::anyhow!(
150                "Currently only two-group comparisons are supported"
151            ));
152        }
153
154        let (group1_indices, group2_indices) = get_group_indices(group_ids, &unique_groups);
155
156        match test_method {
157            TestMethod::TTest(test_type) => {
158                let results = self.t_test(&group1_indices, &group2_indices, test_type)?;
159                let statistics: Vec<_> = results.iter().map(|r| r.statistic).collect();
160                let p_values: Vec<_> = results.iter().map(|r| r.p_value).collect();
161                let adjusted_p_values = correction::benjamini_hochberg_correction(&p_values)?;
162
163                let effect_sizes: Vec<f64> = results
164                    .iter()
165                    .map(|r| r.effect_size.unwrap_or(0.0))
166                    .collect();
167
168                Ok(MultipleTestResults::new(statistics, p_values)
169                    .with_adjusted_p_values(adjusted_p_values)
170                    .with_effect_sizes(effect_sizes)
171                    .with_global_metadata("test_type", "t_test"))
172            }
173
174            TestMethod::MannWhitney => {
175                let results = self.mann_whitney_test(
176                    &group1_indices,
177                    &group2_indices,
178                    Alternative::TwoSided,
179                )?;
180                let statistics: Vec<_> = results.iter().map(|r| r.statistic).collect();
181                let p_values: Vec<_> = results.iter().map(|r| r.p_value).collect();
182                let adjusted_p_values = correction::benjamini_hochberg_correction(&p_values)?;
183
184                Ok(MultipleTestResults::new(statistics, p_values)
185                    .with_adjusted_p_values(adjusted_p_values)
186                    .with_global_metadata("test_type", "mann_whitney"))
187            }
188
189            TestMethod::FisherExact => {
190                let results = self.fisher_exact_test(
191                    &group1_indices,
192                    &group2_indices,
193                    Alternative::TwoSided,
194                )?;
195                let statistics: Vec<_> = results.iter().map(|r| r.statistic.to_f64().unwrap()).collect();
196                let p_values: Vec<_> = results.iter().map(|r| r.p_value.to_f64().unwrap()).collect();
197                let adjusted_p_values = correction::benjamini_hochberg_correction(&p_values)?;
198
199                Ok(MultipleTestResults::new(statistics, p_values)
200                    .with_adjusted_p_values(adjusted_p_values)
201                    .with_global_metadata("test_type", "fisher_exact"))
202            }
203            _ => Err(anyhow::anyhow!("Test method not implemented yet")),
204        }
205    }
206}