Skip to main content

single_statistics/testing/inference/
nonparametric.rs

1//! Non-parametric statistical tests for single-cell data analysis.
2//!
3//! This module implements non-parametric statistical tests that make fewer assumptions about
4//! data distribution. These tests are particularly useful for single-cell data which often
5//! exhibits non-normal distributions, high sparsity, and outliers.
6//!
7//! The primary test implemented is the Mann-Whitney U test (also known as the Wilcoxon 
8//! rank-sum test), which compares the distributions of two groups without assuming normality.
9
10use std::{cmp::Ordering, f64};
11
12use nalgebra_sparse::CsrMatrix;
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14use single_utilities::traits::FloatOpsTS;
15use statrs::distribution::{ContinuousCDF, Normal};
16
17use crate::testing::{Alternative, TestResult};
18use crate::testing::utils::SparseMatrixRef;
19use num_traits::AsPrimitive;
20
21/// Perform Mann-Whitney U tests on all genes comparing two groups of cells.
22pub fn mann_whitney_matrix_groups<T>(
23    matrix: &CsrMatrix<T>,
24    group1_indices: &[usize],
25    group2_indices: &[usize],
26    alternative: Alternative,
27) -> anyhow::Result<Vec<TestResult<f64>>>
28where
29    T: FloatOpsTS,
30    f64: std::convert::From<T>,
31{
32    let smr = SparseMatrixRef {
33        maj_ind: matrix.row_offsets(),
34        min_ind: matrix.col_indices(),
35        val: matrix.values(),
36        n_rows: matrix.nrows(),
37        n_cols: matrix.ncols(),
38    };
39    mann_whitney_sparse(smr, group1_indices, group2_indices, alternative)
40}
41
42/// Perform Mann-Whitney U tests on a sparse matrix represented by raw components.
43pub fn mann_whitney_sparse<T, N, I>(
44    matrix: SparseMatrixRef<T, N, I>,
45    group1_indices: &[usize],
46    group2_indices: &[usize],
47    alternative: Alternative,
48) -> anyhow::Result<Vec<TestResult<f64>>>
49where
50    T: FloatOpsTS,
51    N: AsPrimitive<usize> + Send + Sync,
52    I: AsPrimitive<usize> + Send + Sync,
53    f64: std::convert::From<T>,
54{
55    if group1_indices.is_empty() || group2_indices.is_empty() {
56        return Err(anyhow::anyhow!(
57            "Single-Statistics | Group indices cannot be empty. Error code: SS-NP-001"
58        ));
59    }
60
61    let nrows = matrix.n_rows;
62    let n_group1 = group1_indices.len();
63    let n_group2 = group2_indices.len();
64
65    // Mapping from column index to group ID (0 for none, 1 for group1, 2 for group2)
66    let mut cell_groups = vec![0u8; matrix.n_cols];
67    for &idx in group1_indices {
68        if idx < cell_groups.len() { cell_groups[idx] = 1; }
69    }
70    for &idx in group2_indices {
71        if idx < cell_groups.len() { cell_groups[idx] = 2; }
72    }
73
74    let results: Vec<_> = (0..nrows)
75        .into_par_iter()
76        .map(|row| {
77            let start = matrix.maj_ind[row].as_();
78            let end = matrix.maj_ind[row + 1].as_();
79            
80            let mut x_nonzero = Vec::new();
81            let mut y_nonzero = Vec::new();
82            let mut g1_nz_count = 0;
83            let mut g2_nz_count = 0;
84
85            for i in start..end {
86                let col = matrix.min_ind[i].as_();
87                let val = f64::from(matrix.val[i]);
88                
89                match cell_groups[col] {
90                    1 => {
91                        if val != 0.0 { x_nonzero.push(val); }
92                        g1_nz_count += 1;
93                    },
94                    2 => {
95                        if val != 0.0 { y_nonzero.push(val); }
96                        g2_nz_count += 1;
97                    },
98                    _ => {}
99                }
100            }
101            
102            let x_zeros = n_group1 - g1_nz_count;
103            let y_zeros = n_group2 - g2_nz_count;
104
105            mann_whitney_from_sparse_parts(x_nonzero, y_nonzero, x_zeros, y_zeros, alternative)
106        })
107        .collect();
108
109    Ok(results)
110}
111
112/// Core MW-U logic optimized for sparse scRNA-seq data (many zeros).
113fn mann_whitney_from_sparse_parts(
114    x_nonzero: Vec<f64>,
115    y_nonzero: Vec<f64>,
116    x_zeros: usize,
117    y_zeros: usize,
118    alternative: Alternative,
119) -> TestResult<f64> {
120    let nx = x_zeros + x_nonzero.len();
121    let ny = y_zeros + y_nonzero.len();
122
123    if nx == 0 || ny == 0 {
124        return TestResult::new(f64::NAN, 1.0);
125    }
126
127    let mut combined_nz: Vec<(f64, u8)> = Vec::with_capacity(x_nonzero.len() + y_nonzero.len());
128    for v in x_nonzero { combined_nz.push((v, 0)); }
129    for v in y_nonzero { combined_nz.push((v, 1)); }
130    combined_nz.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
131
132    let n_total = (nx + ny) as f64;
133    let (rank_sum_x, tie_correction) = {
134        let n_zeros = (x_zeros + y_zeros) as f64;
135        let mut rs_x = 0.0;
136        let mut t_corr = 0.0;
137        let mut current_rank = 1.0;
138
139        // Handle Zeros
140        if n_zeros > 0.0 {
141            let avg_rank_zeros = (n_zeros + 1.0) / 2.0;
142            rs_x += (x_zeros as f64) * avg_rank_zeros;
143            t_corr += n_zeros.powi(3) - n_zeros;
144            current_rank += n_zeros;
145        }
146
147        // Handle Non-zeros
148        let mut i = 0;
149        while i < combined_nz.len() {
150            let val = combined_nz[i].0;
151            let start = i;
152            while i < combined_nz.len() && combined_nz[i].0 == val { i += 1; }
153            let count = (i - start) as f64;
154            let avg_rank = current_rank + (count - 1.0) / 2.0;
155            
156            for j in start..i {
157                if combined_nz[j].1 == 0 { rs_x += avg_rank; }
158            }
159            if count > 1.0 {
160                t_corr += count.powi(3) - count;
161            }
162            current_rank += count;
163        }
164        (rs_x, t_corr)
165    };
166
167    let nx_f = nx as f64;
168    let ny_f = ny as f64;
169    let u_x = rank_sum_x - (nx_f * (nx_f + 1.0)) / 2.0;
170    let u_y = (nx_f * ny_f) - u_x;
171    let mean_u = nx_f * ny_f / 2.0;
172    
173    let var_u = (nx_f * ny_f / (n_total * (n_total - 1.0))) * 
174                ((n_total.powi(3) - n_total - tie_correction) / 12.0);
175
176    let (u_stat, z) = match alternative {
177        Alternative::TwoSided => {
178            let u = u_x.min(u_y);
179            let z_score = if var_u > 0.0 {
180                ((u - mean_u).abs() - 0.5).max(0.0) / var_u.sqrt()
181            } else { 0.0 };
182            (u, z_score)
183        },
184        Alternative::Greater => {
185            let z_score = if var_u > 0.0 {
186                (u_x - mean_u - 0.5) / var_u.sqrt()
187            } else { 0.0 };
188            (u_x, z_score)
189        },
190        Alternative::Less => {
191            let z_score = if var_u > 0.0 {
192                (u_x - mean_u + 0.5) / var_u.sqrt()
193            } else { 0.0 };
194            (u_x, z_score)
195        }
196    };
197
198    let p = calculate_p_value(z, alternative, nx_f, ny_f);
199    TestResult::new(u_stat, p)
200        .with_metadata("z_score", z)
201        .with_metadata("var_u", var_u)
202        .with_metadata("tie_correction", tie_correction)
203}
204
205/// Public API for two samples (dense).
206pub fn mann_whitney_optimized(x: &[f64], y: &[f64], alternative: Alternative) -> TestResult<f64> {
207    let mut x_nz = Vec::new();
208    let mut x_z = 0;
209    for &v in x { if v.is_finite() { if v == 0.0 { x_z += 1; } else { x_nz.push(v); } } }
210
211    let mut y_nz = Vec::new();
212    let mut y_z = 0;
213    for &v in y { if v.is_finite() { if v == 0.0 { y_z += 1; } else { y_nz.push(v); } } }
214
215    mann_whitney_from_sparse_parts(x_nz, y_nz, x_z, y_z, alternative)
216}
217
218#[inline]
219fn calculate_p_value(z: f64, alternative: Alternative, nx: f64, ny: f64) -> f64 {
220    if nx < 3.0 || ny < 3.0 { return 1.0; }
221    if !z.is_finite() { return 1.0; }
222
223    let normal = Normal::new(0.0, 1.0).unwrap();
224    match alternative {
225        Alternative::TwoSided => 2.0 * (1.0 - normal.cdf(z.abs())),
226        Alternative::Greater => 1.0 - normal.cdf(z),
227        Alternative::Less => normal.cdf(z),
228    }
229}