single_statistics/testing/inference/
nonparametric.rs1use std::{cmp::Ordering, f64};
11
12use nalgebra_sparse::CsrMatrix;
13use rayon::iter::{IntoParallelIterator, ParallelIterator};
14use single_utilities::traits::FloatOpsTS;
15use statrs::distribution::{ContinuousCDF, Normal};
16
17use crate::testing::{Alternative, TestResult};
18use crate::testing::utils::SparseMatrixRef;
19use num_traits::AsPrimitive;
20
21pub fn mann_whitney_matrix_groups<T>(
23 matrix: &CsrMatrix<T>,
24 group1_indices: &[usize],
25 group2_indices: &[usize],
26 alternative: Alternative,
27) -> anyhow::Result<Vec<TestResult<f64>>>
28where
29 T: FloatOpsTS,
30 f64: std::convert::From<T>,
31{
32 let smr = SparseMatrixRef {
33 maj_ind: matrix.row_offsets(),
34 min_ind: matrix.col_indices(),
35 val: matrix.values(),
36 n_rows: matrix.nrows(),
37 n_cols: matrix.ncols(),
38 };
39 mann_whitney_sparse(smr, group1_indices, group2_indices, alternative)
40}
41
42pub fn mann_whitney_sparse<T, N, I>(
44 matrix: SparseMatrixRef<T, N, I>,
45 group1_indices: &[usize],
46 group2_indices: &[usize],
47 alternative: Alternative,
48) -> anyhow::Result<Vec<TestResult<f64>>>
49where
50 T: FloatOpsTS,
51 N: AsPrimitive<usize> + Send + Sync,
52 I: AsPrimitive<usize> + Send + Sync,
53 f64: std::convert::From<T>,
54{
55 if group1_indices.is_empty() || group2_indices.is_empty() {
56 return Err(anyhow::anyhow!(
57 "Single-Statistics | Group indices cannot be empty. Error code: SS-NP-001"
58 ));
59 }
60
61 let nrows = matrix.n_rows;
62 let n_group1 = group1_indices.len();
63 let n_group2 = group2_indices.len();
64
65 let mut cell_groups = vec![0u8; matrix.n_cols];
67 for &idx in group1_indices {
68 if idx < cell_groups.len() { cell_groups[idx] = 1; }
69 }
70 for &idx in group2_indices {
71 if idx < cell_groups.len() { cell_groups[idx] = 2; }
72 }
73
74 let results: Vec<_> = (0..nrows)
75 .into_par_iter()
76 .map(|row| {
77 let start = matrix.maj_ind[row].as_();
78 let end = matrix.maj_ind[row + 1].as_();
79
80 let mut x_nonzero = Vec::new();
81 let mut y_nonzero = Vec::new();
82 let mut g1_nz_count = 0;
83 let mut g2_nz_count = 0;
84
85 for i in start..end {
86 let col = matrix.min_ind[i].as_();
87 let val = f64::from(matrix.val[i]);
88
89 match cell_groups[col] {
90 1 => {
91 if val != 0.0 { x_nonzero.push(val); }
92 g1_nz_count += 1;
93 },
94 2 => {
95 if val != 0.0 { y_nonzero.push(val); }
96 g2_nz_count += 1;
97 },
98 _ => {}
99 }
100 }
101
102 let x_zeros = n_group1 - g1_nz_count;
103 let y_zeros = n_group2 - g2_nz_count;
104
105 mann_whitney_from_sparse_parts(x_nonzero, y_nonzero, x_zeros, y_zeros, alternative)
106 })
107 .collect();
108
109 Ok(results)
110}
111
112fn mann_whitney_from_sparse_parts(
114 x_nonzero: Vec<f64>,
115 y_nonzero: Vec<f64>,
116 x_zeros: usize,
117 y_zeros: usize,
118 alternative: Alternative,
119) -> TestResult<f64> {
120 let nx = x_zeros + x_nonzero.len();
121 let ny = y_zeros + y_nonzero.len();
122
123 if nx == 0 || ny == 0 {
124 return TestResult::new(f64::NAN, 1.0);
125 }
126
127 let mut combined_nz: Vec<(f64, u8)> = Vec::with_capacity(x_nonzero.len() + y_nonzero.len());
128 for v in x_nonzero { combined_nz.push((v, 0)); }
129 for v in y_nonzero { combined_nz.push((v, 1)); }
130 combined_nz.sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
131
132 let n_total = (nx + ny) as f64;
133 let (rank_sum_x, tie_correction) = {
134 let n_zeros = (x_zeros + y_zeros) as f64;
135 let mut rs_x = 0.0;
136 let mut t_corr = 0.0;
137 let mut current_rank = 1.0;
138
139 if n_zeros > 0.0 {
141 let avg_rank_zeros = (n_zeros + 1.0) / 2.0;
142 rs_x += (x_zeros as f64) * avg_rank_zeros;
143 t_corr += n_zeros.powi(3) - n_zeros;
144 current_rank += n_zeros;
145 }
146
147 let mut i = 0;
149 while i < combined_nz.len() {
150 let val = combined_nz[i].0;
151 let start = i;
152 while i < combined_nz.len() && combined_nz[i].0 == val { i += 1; }
153 let count = (i - start) as f64;
154 let avg_rank = current_rank + (count - 1.0) / 2.0;
155
156 for j in start..i {
157 if combined_nz[j].1 == 0 { rs_x += avg_rank; }
158 }
159 if count > 1.0 {
160 t_corr += count.powi(3) - count;
161 }
162 current_rank += count;
163 }
164 (rs_x, t_corr)
165 };
166
167 let nx_f = nx as f64;
168 let ny_f = ny as f64;
169 let u_x = rank_sum_x - (nx_f * (nx_f + 1.0)) / 2.0;
170 let u_y = (nx_f * ny_f) - u_x;
171 let mean_u = nx_f * ny_f / 2.0;
172
173 let var_u = (nx_f * ny_f / (n_total * (n_total - 1.0))) *
174 ((n_total.powi(3) - n_total - tie_correction) / 12.0);
175
176 let (u_stat, z) = match alternative {
177 Alternative::TwoSided => {
178 let u = u_x.min(u_y);
179 let z_score = if var_u > 0.0 {
180 ((u - mean_u).abs() - 0.5).max(0.0) / var_u.sqrt()
181 } else { 0.0 };
182 (u, z_score)
183 },
184 Alternative::Greater => {
185 let z_score = if var_u > 0.0 {
186 (u_x - mean_u - 0.5) / var_u.sqrt()
187 } else { 0.0 };
188 (u_x, z_score)
189 },
190 Alternative::Less => {
191 let z_score = if var_u > 0.0 {
192 (u_x - mean_u + 0.5) / var_u.sqrt()
193 } else { 0.0 };
194 (u_x, z_score)
195 }
196 };
197
198 let p = calculate_p_value(z, alternative, nx_f, ny_f);
199 TestResult::new(u_stat, p)
200 .with_metadata("z_score", z)
201 .with_metadata("var_u", var_u)
202 .with_metadata("tie_correction", tie_correction)
203}
204
205pub fn mann_whitney_optimized(x: &[f64], y: &[f64], alternative: Alternative) -> TestResult<f64> {
207 let mut x_nz = Vec::new();
208 let mut x_z = 0;
209 for &v in x { if v.is_finite() { if v == 0.0 { x_z += 1; } else { x_nz.push(v); } } }
210
211 let mut y_nz = Vec::new();
212 let mut y_z = 0;
213 for &v in y { if v.is_finite() { if v == 0.0 { y_z += 1; } else { y_nz.push(v); } } }
214
215 mann_whitney_from_sparse_parts(x_nz, y_nz, x_z, y_z, alternative)
216}
217
218#[inline]
219fn calculate_p_value(z: f64, alternative: Alternative, nx: f64, ny: f64) -> f64 {
220 if nx < 3.0 || ny < 3.0 { return 1.0; }
221 if !z.is_finite() { return 1.0; }
222
223 let normal = Normal::new(0.0, 1.0).unwrap();
224 match alternative {
225 Alternative::TwoSided => 2.0 * (1.0 - normal.cdf(z.abs())),
226 Alternative::Greater => 1.0 - normal.cdf(z),
227 Alternative::Less => normal.cdf(z),
228 }
229}