single_statistics/testing/inference/
parametric.rs1use crate::testing::utils::{accumulate_gene_statistics_two_groups_raw, SparseMatrixRef};
8use crate::testing::{TTestType, TestResult};
9use nalgebra_sparse::CsrMatrix;
10use single_utilities::traits::{FloatOps, FloatOpsTS};
11use statrs::distribution::{ContinuousCDF, StudentsT};
12use num_traits::AsPrimitive;
13
14pub fn t_test_matrix_groups<T>(
30 matrix: &CsrMatrix<T>,
31 group1_indices: &[usize],
32 group2_indices: &[usize],
33 test_type: TTestType,
34) -> anyhow::Result<Vec<TestResult<f64>>>
35where
36 T: FloatOpsTS,
37{
38 let smr = SparseMatrixRef {
39 maj_ind: matrix.row_offsets(),
40 min_ind: matrix.col_indices(),
41 val: matrix.values(),
42 n_rows: matrix.nrows(),
43 n_cols: matrix.ncols(),
44 };
45 t_test_sparse(smr, group1_indices, group2_indices, test_type)
46}
47
48pub fn t_test_sparse<T, N, I>(
52 matrix: SparseMatrixRef<T, N, I>,
53 group1_indices: &[usize],
54 group2_indices: &[usize],
55 test_type: TTestType,
56) -> anyhow::Result<Vec<TestResult<f64>>>
57where
58 T: FloatOpsTS,
59 N: AsPrimitive<usize> + Send + Sync,
60 I: AsPrimitive<usize> + Send + Sync,
61{
62 if group1_indices.is_empty() || group2_indices.is_empty() {
63 return Err(anyhow::anyhow!("Group indices cannot be empty"));
64 }
65
66 let n_genes = matrix.n_rows;
67 let group1_size = T::from(group1_indices.len()).unwrap();
68 let group2_size = T::from(group2_indices.len()).unwrap();
69
70 let (group1_sums, group1_sum_squares, group2_sums, group2_sum_squares) =
71 accumulate_gene_statistics_two_groups_raw(matrix, group1_indices, group2_indices, n_genes)?;
72
73 let results: Vec<TestResult<f64>> = (0..n_genes)
74 .map(|gene_idx| {
75 fast_t_test_from_sums(
76 group1_sums[gene_idx].to_f64().unwrap(),
77 group1_sum_squares[gene_idx].to_f64().unwrap(),
78 group1_size.to_f64().unwrap(),
79 group2_sums[gene_idx].to_f64().unwrap(),
80 group2_sum_squares[gene_idx].to_f64().unwrap(),
81 group2_size.to_f64().unwrap(),
82 test_type,
83 )
84 })
85 .collect();
86
87 Ok(results)
88}
89
90pub fn t_test<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
105where
106 T: FloatOps,
107{
108 let nx = x.len();
109 let ny = y.len();
110
111 if nx < 2 || ny < 2 {
112 return TestResult::new(0.0, 1.0);
113 }
114
115 if nx + ny < 1000 {
117 t_test_small_optimized(x, y, test_type)
119 } else {
120 t_test_large(x, y, test_type)
122 }
123}
124
125#[inline]
126fn t_test_small_optimized<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
127where
128 T: FloatOps,
129{
130 let mut sum_x = T::zero();
132 let mut sum_sq_x = T::zero();
133 for &val in x {
134 sum_x += val;
135 sum_sq_x += val * val;
136 }
137
138 let mut sum_y = T::zero();
139 let mut sum_sq_y = T::zero();
140 for &val in y {
141 sum_y += val;
142 sum_sq_y += val * val;
143 }
144
145 let nx_f = T::from(x.len()).unwrap();
146 let ny_f = T::from(y.len()).unwrap();
147
148 fast_t_test_from_sums(
149 sum_x.to_f64().unwrap(),
150 sum_sq_x.to_f64().unwrap(),
151 nx_f.to_f64().unwrap(),
152 sum_y.to_f64().unwrap(),
153 sum_sq_y.to_f64().unwrap(),
154 ny_f.to_f64().unwrap(),
155 test_type
156 )
157}
158
159#[inline]
160fn t_test_large<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
161where
162 T: FloatOps,
163{
164 const CHUNK_SIZE: usize = 256;
166
167 let mut sum_x = T::zero();
168 let mut sum_sq_x = T::zero();
169
170 for chunk in x.chunks(CHUNK_SIZE) {
171 for &val in chunk {
172 sum_x += val;
173 sum_sq_x += val * val;
174 }
175 }
176
177 let mut sum_y = T::zero();
178 let mut sum_sq_y = T::zero();
179
180 for chunk in y.chunks(CHUNK_SIZE) {
181 for &val in chunk {
182 sum_y += val;
183 sum_sq_y += val * val;
184 }
185 }
186
187 let nx_f = T::from(x.len()).unwrap();
188 let ny_f = T::from(y.len()).unwrap();
189
190 fast_t_test_from_sums(
191 sum_x.to_f64().unwrap(),
192 sum_sq_x.to_f64().unwrap(),
193 nx_f.to_f64().unwrap(),
194 sum_y.to_f64().unwrap(),
195 sum_sq_y.to_f64().unwrap(),
196 ny_f.to_f64().unwrap(),
197 test_type
198 )
199}
200
201pub fn fast_t_test_from_sums(
218 sum1: f64,
219 sum_sq1: f64,
220 n1: f64,
221 sum2: f64,
222 sum_sq2: f64,
223 n2: f64,
224 test_type: TTestType,
225) -> TestResult<f64>
226{
227 if n1 < 2.0 || n2 < 2.0 {
229 return TestResult::new(0.0, 1.0);
230 }
231
232 let mean1 = sum1 / n1;
234 let mean2 = sum2 / n2;
235
236 let var1 = (sum_sq1 - sum1 * sum1 / n1) / (n1 - 1.0);
238 let var2 = (sum_sq2 - sum2 * sum2 / n2) / (n2 - 1.0);
239
240 let mean_diff = mean1 - mean2;
241
242 let (t_stat, df) = match test_type {
243 TTestType::Student => {
244 let pooled_var = ((n1 - 1.0) * var1 + (n2 - 1.0) * var2) / (n1 + n2 - 2.0);
246 let std_err = (pooled_var * (1.0 / n1 + 1.0 / n2)).sqrt();
247 (mean_diff / std_err, n1 + n2 - 2.0)
248 }
249 TTestType::Welch => {
250 let term1 = var1 / n1;
252 let term2 = var2 / n2;
253 let combined_var = term1 + term2;
254 let std_err = combined_var.sqrt();
255 let t = mean_diff / std_err;
256
257 let df = combined_var * combined_var /
259 (term1 * term1 / (n1 - 1.0) + term2 * term2 / (n2 - 1.0));
260 (t, df)
261 }
262 };
263
264 let p_value = fast_t_test_p_value(t_stat, df);
265 TestResult::new(t_stat, p_value)
266}
267
268#[inline]
269fn fast_t_test_p_value(t_stat: f64, df: f64) -> f64
270{
271 if !t_stat.is_finite() {
273 return if t_stat.is_infinite() { 0.0 } else { 1.0 };
274 }
275
276 if df <= 0.0 || !df.is_finite() {
277 return 1.0;
278 }
279
280 let abs_t = t_stat.abs();
281
282 if abs_t < 0.001 {
284 return 1.0; }
286
287 if abs_t > 37.0 {
289 let log_p = log_normal_tail_probability(abs_t);
290 return 2.0 * log_p.exp();
291 }
292
293 if df > 100.0 {
295 return 2.0 * high_precision_normal_cdf_complement(abs_t);
296 }
297
298 match StudentsT::new(0.0, 1.0, df) {
300 Ok(t_dist) => {
301 let cdf_val = t_dist.cdf(abs_t);
302 2.0 * (1.0 - cdf_val)
303 }
304 Err(_) => 1.0,
305 }
306}
307
308#[inline]
310fn log_normal_tail_probability(x: f64) -> f64 {
311 if x < 0.0 {
312 return 0.0;
313 }
314
315 if x > 8.0 {
316 let x_sq = x * x;
317 return -0.5 * x_sq - (x * (2.0 * std::f64::consts::PI).sqrt()).ln();
318 }
319
320 let z = x / (2.0_f64).sqrt();
321 log_erfc(z) - (2.0_f64).ln()
322}
323
324#[inline]
326fn log_erfc(x: f64) -> f64 {
327 if x < 0.0 {
328 return 0.0;
329 }
330
331 if x > 26.0 {
332 let x_sq = x * x;
333 return -x_sq - 0.5 * (std::f64::consts::PI).ln() - x.ln();
334 }
335
336 continued_fraction_log_erfc(x)
337}
338
339#[inline]
341fn continued_fraction_log_erfc(x: f64) -> f64 {
342 if x < 2.0 {
343 let erf_val = erf_series(x);
344 return (1.0 - erf_val).ln();
345 }
346
347 let x_sq = x * x;
348 let mut a = 1.0;
349 let mut b = 2.0 * x_sq;
350 let mut result = a / b;
351
352 for n in 1..50 {
353 a = -(2 * n - 1) as f64;
354 b = 2.0 * x_sq + a / result;
355 let new_result = a / b;
356
357 if (result - new_result).abs() < 1e-15 {
358 break;
359 }
360 result = new_result;
361 }
362
363 -x_sq + (result / (x * (std::f64::consts::PI).sqrt())).ln()
364}
365
/// Error function `erf(x)` from its Maclaurin series
///
/// `erf(x) = 2/√π · Σ_{n≥0} (-1)ⁿ x^(2n+1) / (n! · (2n+1))`.
///
/// Terms are added until one drops below `1e-16` in magnitude, with at most
/// 99 terms after the leading one. The callers in this file only use it for
/// `|x| < 2`, where the alternating series is still well conditioned.
#[inline]
fn erf_series(x: f64) -> f64 {
    let neg_x_squared = -(x * x);

    // `power_term` tracks (-1)ⁿ x^(2n+1) / n!; `acc` is the running sum.
    let mut power_term = x;
    let mut acc = power_term;

    for k in 1..100 {
        let kf = k as f64;
        power_term *= neg_x_squared / kf;

        let contribution = power_term / (2.0 * kf + 1.0);
        acc += contribution;

        if contribution.abs() < 1e-16 {
            break;
        }
    }

    acc * 2.0 / (std::f64::consts::PI).sqrt()
}
385
386#[inline]
388fn high_precision_normal_cdf_complement(x: f64) -> f64 {
389 if x < 0.0 {
390 return 1.0 - high_precision_normal_cdf_complement(-x);
391 }
392
393 if x > 37.0 {
394 let log_p = log_normal_tail_probability(x);
395 return log_p.exp();
396 }
397
398 0.5 * erfc_high_precision(x / (2.0_f64).sqrt())
399}
400
401#[inline]
403fn erfc_high_precision(x: f64) -> f64 {
404 if x < 0.0 {
405 return 2.0 - erfc_high_precision(-x);
406 }
407
408 if x > 26.0 {
409 return 0.0;
410 }
411
412 if x < 2.0 {
413 return 1.0 - erf_series(x);
414 }
415 chebyshev_erfc(x)
416}
417
/// Complementary error function `erfc(x)` for the tail region `x >= 2`.
///
/// Evaluates the continued fraction
///
/// `erfc(x) = exp(-x²) / (√π · (x + (1/2)/(x + 1/(x + (3/2)/(x + …)))))`
///
/// bottom-up with a fixed number of terms (the k-th partial numerator is
/// `k / 2`). The fraction converges quickly for `x >= 2`, which is the only
/// range the caller in this file uses; it is not suitable for `x` near zero.
///
/// The function keeps its historical name for the sake of existing callers,
/// although no Chebyshev expansion is involved.
#[inline]
fn chebyshev_erfc(x: f64) -> f64 {
    // The truncation error shrinks roughly like exp(-c·√n); 128 terms is far
    // more than double precision needs at x = 2, and convergence only gets
    // faster as x grows.
    const TERMS: u32 = 128;

    let mut denominator = x;
    for k in (1..=TERMS).rev() {
        denominator = x + 0.5 * f64::from(k) / denominator;
    }

    (-x * x).exp() / ((std::f64::consts::PI).sqrt() * denominator)
}