// single_statistics/testing/inference/parametric.rs

//! Parametric statistical tests for single-cell data analysis.
//!
//! This module implements parametric statistical tests, primarily t-tests, optimized for
//! sparse single-cell expression matrices. The implementations are designed for efficiency
//! when testing thousands of genes across different cell groups.
7use crate::testing::utils::{accumulate_gene_statistics_two_groups_raw, SparseMatrixRef};
8use crate::testing::{TTestType, TestResult};
9use nalgebra_sparse::CsrMatrix;
10use single_utilities::traits::{FloatOps, FloatOpsTS};
11use statrs::distribution::{ContinuousCDF, StudentsT};
12use num_traits::AsPrimitive;
13
14/// Perform t-tests on all genes comparing two groups of cells.
15///
16/// This is an optimized implementation that efficiently computes summary statistics
17/// for sparse matrices and performs t-tests for each gene.
18///
19/// # Arguments
20///
21/// * `matrix` - Sparse expression matrix (cells × genes)
22/// * `group1_indices` - Column indices for the first group of cells
23/// * `group2_indices` - Column indices for the second group of cells
24/// * `test_type` - Type of t-test to perform (Student's or Welch's)
25///
26/// # Returns
27///
28/// Vector of `TestResult` objects, one per gene, containing t-statistics and p-values.
29pub fn t_test_matrix_groups<T>(
30    matrix: &CsrMatrix<T>,
31    group1_indices: &[usize],
32    group2_indices: &[usize],
33    test_type: TTestType,
34) -> anyhow::Result<Vec<TestResult<f64>>>
35where
36    T: FloatOpsTS,
37{
38    let smr = SparseMatrixRef {
39        maj_ind: matrix.row_offsets(),
40        min_ind: matrix.col_indices(),
41        val: matrix.values(),
42        n_rows: matrix.nrows(),
43        n_cols: matrix.ncols(),
44    };
45    t_test_sparse(smr, group1_indices, group2_indices, test_type)
46}
47
48/// Perform t-tests on a sparse matrix represented by raw components.
49/// 
50/// This version is agnostic of the matrix container and can be used with raw vectors.
51pub fn t_test_sparse<T, N, I>(
52    matrix: SparseMatrixRef<T, N, I>,
53    group1_indices: &[usize],
54    group2_indices: &[usize],
55    test_type: TTestType,
56) -> anyhow::Result<Vec<TestResult<f64>>>
57where
58    T: FloatOpsTS,
59    N: AsPrimitive<usize> + Send + Sync,
60    I: AsPrimitive<usize> + Send + Sync,
61{
62    if group1_indices.is_empty() || group2_indices.is_empty() {
63        return Err(anyhow::anyhow!("Group indices cannot be empty"));
64    }
65
66    let n_genes = matrix.n_rows;
67    let group1_size = T::from(group1_indices.len()).unwrap();
68    let group2_size = T::from(group2_indices.len()).unwrap();
69
70    let (group1_sums, group1_sum_squares, group2_sums, group2_sum_squares) =
71        accumulate_gene_statistics_two_groups_raw(matrix, group1_indices, group2_indices, n_genes)?;
72
73    let results: Vec<TestResult<f64>> = (0..n_genes)
74        .map(|gene_idx| {
75            fast_t_test_from_sums(
76                group1_sums[gene_idx].to_f64().unwrap(),
77                group1_sum_squares[gene_idx].to_f64().unwrap(),
78                group1_size.to_f64().unwrap(),
79                group2_sums[gene_idx].to_f64().unwrap(),
80                group2_sum_squares[gene_idx].to_f64().unwrap(),
81                group2_size.to_f64().unwrap(),
82                test_type,
83            )
84        })
85        .collect();
86
87    Ok(results)
88}
89
90/// Perform a t-test comparing two samples.
91///
92/// This function performs either Student's t-test (assuming equal variances) or
93/// Welch's t-test (allowing unequal variances) on two samples.
94///
95/// # Arguments
96///
97/// * `x` - First sample
98/// * `y` - Second sample  
99/// * `test_type` - Type of t-test to perform
100///
101/// # Returns
102///
103/// `TestResult` containing the t-statistic and p-value.
104pub fn t_test<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
105where
106    T: FloatOps,
107{
108    let nx = x.len();
109    let ny = y.len();
110
111    if nx < 2 || ny < 2 {
112        return TestResult::new(0.0, 1.0);
113    }
114
115    // Branch optimization: use different strategies based on size
116    if nx + ny < 1000 {
117        // For small datasets, optimize for simplicity and cache locality
118        t_test_small_optimized(x, y, test_type)
119    } else {
120        // For larger datasets, use the original approach
121        t_test_large(x, y, test_type)
122    }
123}
124
125#[inline]
126fn t_test_small_optimized<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
127where
128    T: FloatOps,
129{
130    // Optimized single-pass computation with better locality
131    let mut sum_x = T::zero();
132    let mut sum_sq_x = T::zero();
133    for &val in x {
134        sum_x += val;
135        sum_sq_x += val * val;
136    }
137
138    let mut sum_y = T::zero();
139    let mut sum_sq_y = T::zero();
140    for &val in y {
141        sum_y += val;
142        sum_sq_y += val * val;
143    }
144
145    let nx_f = T::from(x.len()).unwrap();
146    let ny_f = T::from(y.len()).unwrap();
147
148    fast_t_test_from_sums(
149        sum_x.to_f64().unwrap(), 
150        sum_sq_x.to_f64().unwrap(), 
151        nx_f.to_f64().unwrap(), 
152        sum_y.to_f64().unwrap(), 
153        sum_sq_y.to_f64().unwrap(), 
154        ny_f.to_f64().unwrap(), 
155        test_type
156    )
157}
158
159#[inline]
160fn t_test_large<T>(x: &[T], y: &[T], test_type: TTestType) -> TestResult<f64>
161where
162    T: FloatOps,
163{
164    // For larger datasets, use chunked processing to improve cache efficiency
165    const CHUNK_SIZE: usize = 256;
166    
167    let mut sum_x = T::zero();
168    let mut sum_sq_x = T::zero();
169    
170    for chunk in x.chunks(CHUNK_SIZE) {
171        for &val in chunk {
172            sum_x += val;
173            sum_sq_x += val * val;
174        }
175    }
176
177    let mut sum_y = T::zero();
178    let mut sum_sq_y = T::zero();
179    
180    for chunk in y.chunks(CHUNK_SIZE) {
181        for &val in chunk {
182            sum_y += val;
183            sum_sq_y += val * val;
184        }
185    }
186
187    let nx_f = T::from(x.len()).unwrap();
188    let ny_f = T::from(y.len()).unwrap();
189
190    fast_t_test_from_sums(
191        sum_x.to_f64().unwrap(), 
192        sum_sq_x.to_f64().unwrap(), 
193        nx_f.to_f64().unwrap(), 
194        sum_y.to_f64().unwrap(), 
195        sum_sq_y.to_f64().unwrap(), 
196        ny_f.to_f64().unwrap(), 
197        test_type
198    )
199}
200
201/// Perform a t-test using precomputed summary statistics.
202///
203/// This is an optimized function that computes t-tests directly from sum and sum-of-squares,
204/// avoiding the need to store or iterate through the original data. Particularly useful for
205/// sparse matrix operations where computing these statistics is done efficiently during
206/// matrix traversal.
207///
208/// # Arguments
209///
210/// * `sum1`, `sum_sq1`, `n1` - Sum, sum of squares, and count for group 1
211/// * `sum2`, `sum_sq2`, `n2` - Sum, sum of squares, and count for group 2
212/// * `test_type` - Type of t-test to perform (Student's or Welch's)
213///
214/// # Returns
215///
216/// `TestResult` containing the t-statistic and p-value.
217pub fn fast_t_test_from_sums(
218    sum1: f64,
219    sum_sq1: f64,
220    n1: f64,
221    sum2: f64,
222    sum_sq2: f64,
223    n2: f64,
224    test_type: TTestType,
225) -> TestResult<f64>
226{
227    // Early exit for insufficient sample sizes
228    if n1 < 2.0 || n2 < 2.0 {
229        return TestResult::new(0.0, 1.0);
230    }
231
232    // Calculate means directly (avoiding redundant assignments)
233    let mean1 = sum1 / n1;
234    let mean2 = sum2 / n2;
235
236    // Calculate variances using the computational formula
237    let var1 = (sum_sq1 - sum1 * sum1 / n1) / (n1 - 1.0);
238    let var2 = (sum_sq2 - sum2 * sum2 / n2) / (n2 - 1.0);
239    
240    let mean_diff = mean1 - mean2;
241    
242    let (t_stat, df) = match test_type {
243        TTestType::Student => {
244            // Student's t-test (pooled variance)
245            let pooled_var = ((n1 - 1.0) * var1 + (n2 - 1.0) * var2) / (n1 + n2 - 2.0);
246            let std_err = (pooled_var * (1.0 / n1 + 1.0 / n2)).sqrt();
247            (mean_diff / std_err, n1 + n2 - 2.0)
248        }
249        TTestType::Welch => {
250            // Welch's t-test (unequal variances)
251            let term1 = var1 / n1;
252            let term2 = var2 / n2;
253            let combined_var = term1 + term2;
254            let std_err = combined_var.sqrt();
255            let t = mean_diff / std_err;
256            
257            // Welch-Satterthwaite equation for degrees of freedom
258            let df = combined_var * combined_var / 
259                (term1 * term1 / (n1 - 1.0) + term2 * term2 / (n2 - 1.0));
260            (t, df)
261        }
262    };
263
264    let p_value = fast_t_test_p_value(t_stat, df);
265    TestResult::new(t_stat, p_value)
266}
267
268#[inline]
269fn fast_t_test_p_value(t_stat: f64, df: f64) -> f64
270{
271    // Fast path for non-finite inputs
272    if !t_stat.is_finite() {
273        return if t_stat.is_infinite() { 0.0 } else { 1.0 };
274    }
275
276    if df <= 0.0 || !df.is_finite() {
277        return 1.0;
278    }
279
280    let abs_t = t_stat.abs();
281
282    // Fast path for very small t-statistics (common case)
283    if abs_t < 0.001 {
284        return 1.0; // p-value ≈ 1 for very small effects
285    }
286
287    // Early return for very large t-statistics (avoids expensive computations)
288    if abs_t > 37.0 {
289        let log_p = log_normal_tail_probability(abs_t);
290        return 2.0 * log_p.exp();
291    }
292
293    // Use normal approximation for large degrees of freedom (faster than t-distribution)
294    if df > 100.0 {
295        return 2.0 * high_precision_normal_cdf_complement(abs_t);
296    }
297
298    // Only create StudentsT distribution when necessary
299    match StudentsT::new(0.0, 1.0, df) {
300        Ok(t_dist) => {
301            let cdf_val = t_dist.cdf(abs_t);
302            2.0 * (1.0 - cdf_val)
303        }
304        Err(_) => 1.0,
305    }
306}
307
308/// High-precision calculation of log(P(Z > x)) for standard normal
309#[inline]
310fn log_normal_tail_probability(x: f64) -> f64 {
311    if x < 0.0 {
312        return 0.0; 
313    }
314    
315    if x > 8.0 {
316        let x_sq = x * x;
317        return -0.5 * x_sq - (x * (2.0 * std::f64::consts::PI).sqrt()).ln();
318    }
319
320    let z = x / (2.0_f64).sqrt();
321    log_erfc(z) - (2.0_f64).ln()
322}
323
324/// High-precision complementary error function for extreme values
325#[inline]
326fn log_erfc(x: f64) -> f64 {
327    if x < 0.0 {
328        return 0.0;
329    }
330    
331    if x > 26.0 {
332        let x_sq = x * x;
333        return -x_sq - 0.5 * (std::f64::consts::PI).ln() - x.ln();
334    }
335
336    continued_fraction_log_erfc(x)
337}
338
339/// Continued fraction approximation for log(erfc(x))
340#[inline]
341fn continued_fraction_log_erfc(x: f64) -> f64 {
342    if x < 2.0 {
343        let erf_val = erf_series(x);
344        return (1.0 - erf_val).ln();
345    }
346    
347    let x_sq = x * x;
348    let mut a = 1.0;
349    let mut b = 2.0 * x_sq;
350    let mut result = a / b;
351    
352    for n in 1..50 {
353        a = -(2 * n - 1) as f64;
354        b = 2.0 * x_sq + a / result;
355        let new_result = a / b;
356        
357        if (result - new_result).abs() < 1e-15 {
358            break;
359        }
360        result = new_result;
361    }
362    
363    -x_sq + (result / (x * (std::f64::consts::PI).sqrt())).ln()
364}
365
/// Maclaurin series for erf(x): (2/√π) · Σ (-1)ⁿ x^(2n+1) / (n!·(2n+1)).
///
/// Accurate for small |x|; terms are accumulated until they drop below 1e-16.
#[inline]
fn erf_series(x: f64) -> f64 {
    let x_sq = x * x;
    // `power_term` carries (-1)ⁿ · x^(2n+1) / n!; dividing by (2n+1) gives
    // each series term.
    let mut power_term = x;
    let mut total = x;

    for n in 1..100 {
        power_term *= -x_sq / (n as f64);
        let term = power_term / (2.0 * n as f64 + 1.0);
        total += term;

        if term.abs() < 1e-16 {
            break;
        }
    }

    total * 2.0 / std::f64::consts::PI.sqrt()
}
385
386/// High-precision normal CDF complement for extreme values
387#[inline]
388fn high_precision_normal_cdf_complement(x: f64) -> f64 {
389    if x < 0.0 {
390        return 1.0 - high_precision_normal_cdf_complement(-x);
391    }
392    
393    if x > 37.0 {
394        let log_p = log_normal_tail_probability(x);
395        return log_p.exp();
396    }
397    
398    0.5 * erfc_high_precision(x / (2.0_f64).sqrt())
399}
400
401/// High-precision complementary error function
402#[inline]
403fn erfc_high_precision(x: f64) -> f64 {
404    if x < 0.0 {
405        return 2.0 - erfc_high_precision(-x);
406    }
407    
408    if x > 26.0 {
409        return 0.0; 
410    }
411    
412    if x < 2.0 {
413        return 1.0 - erf_series(x);
414    }
415    chebyshev_erfc(x)
416}
417
/// Rational approximation for erfc(x), x ≥ 0 (Abramowitz & Stegun 7.1.26).
///
/// The previous coefficients mixed two A&S formulas: the t = 1/(1 + px)
/// substitution and e^{-x²} factor of 7.1.26 were combined with the
/// polynomial coefficients of the 7.1.25/7.1.27 family, yielding e.g.
/// ≈0.00110 at x = 2 versus the true erfc(2) ≈ 0.00468. These are the
/// matching 7.1.26 coefficients; maximum absolute error is 1.5e-7.
#[inline]
fn chebyshev_erfc(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;

    let t = 1.0 / (1.0 + P * x);
    // Horner evaluation of t·(A1 + t·(A2 + t·(A3 + t·(A4 + t·A5)))).
    let poly = t * (A1 + t * (A2 + t * (A3 + t * (A4 + t * A5))));

    poly * (-x * x).exp()
}