Skip to main content

single_statistics/testing/
utils.rs

1//! Utility functions for statistical testing operations.
2
3use single_utilities::traits::FloatOpsTS;
4use num_traits::AsPrimitive;
5
/// Borrowed, container-agnostic view over a compressed sparse matrix (CSR or CSC).
///
/// The view only holds slices, so it can wrap raw vectors owned by other crates or
/// handed across an FFI boundary (for example from PyO3) without copying.
///
/// Type parameters:
/// * `T` - element type of the stored values
/// * `N` - integer type of the major index pointers
/// * `I` - integer type of the minor indices
#[derive(Clone, Copy, Debug)]
pub struct SparseMatrixRef<'a, T, N, I> {
    /// Offsets along the major axis (the `indptr` array of CSR/CSC).
    pub maj_ind: &'a [N],
    /// Minor-axis index of each stored value (column indices for CSR, row indices for CSC).
    pub min_ind: &'a [I],
    /// Stored (explicit) values of the matrix.
    pub val: &'a [T],
    /// Row count of the matrix.
    pub n_rows: usize,
    /// Column count of the matrix.
    pub n_cols: usize,
}
23
24impl<'a, T, N, I> SparseMatrixRef<'a, T, N, I>
25where
26    T: FloatOpsTS,
27    N: AsPrimitive<usize> + Send + Sync,
28    I: AsPrimitive<usize> + Send + Sync,
29{
30    /// Create a new SparseMatrixRef
31    pub fn new(maj_ind: &'a [N], min_ind: &'a [I], val: &'a [T], n_rows: usize, n_cols: usize) -> Self {
32        Self { maj_ind, min_ind, val, n_rows, n_cols }
33    }
34
35    /// Get the data for a specific major index (row in CSR, column in CSC)
36    #[inline]
37    pub fn get_major(&self, idx: usize) -> (&'a [I], &'a [T]) {
38        let start: usize = self.maj_ind[idx].as_();
39        let end: usize = self.maj_ind[idx + 1].as_();
40        (&self.min_ind[start..end], &self.val[start..end])
41    }
42
43    /// Get a specific entry (row, col) from the sparse matrix.
44    /// 
45    /// This assumes CSR format (row is major index).
46    pub fn get_entry(&self, row: usize, col: usize) -> T {
47        let (indices, values) = self.get_major(row);
48        match indices.binary_search_by(|&i| i.as_().cmp(&col)) {
49            Ok(idx) => values[idx],
50            Err(_) => T::zero(),
51        }
52    }
53}
54
/// Collect the distinct group identifiers present in a group assignment slice.
///
/// The result is sorted in ascending order and contains each identifier exactly once.
/// An empty input yields an empty vector.
pub fn extract_unique_groups(group_ids: &[usize]) -> Vec<usize> {
    let mut distinct = Vec::from(group_ids);
    // Equal `usize` keys are indistinguishable, so an unstable sort gives the same result.
    distinct.sort_unstable();
    // After sorting, duplicates are adjacent and `dedup` removes them all.
    distinct.dedup();
    distinct
}
64
/// Split cell positions into the two groups being compared.
///
/// The two group labels are taken from `unique_groups[0]` and `unique_groups[1]`;
/// any further entries in `unique_groups` are ignored.
///
/// # Arguments
///
/// * `group_ids` - group label of each cell, indexed by cell position
/// * `unique_groups` - group labels; the first two select the groups to extract
///
/// # Returns
///
/// A tuple `(group1_indices, group2_indices)` holding, in ascending order, the
/// positions in `group_ids` whose label equals the first and second group label
/// respectively. If both labels are equal, each matching position appears in both
/// vectors.
///
/// # Panics
///
/// Panics if `unique_groups` contains fewer than two entries.
pub fn get_group_indices(group_ids: &[usize], unique_groups: &[usize]) -> (Vec<usize>, Vec<usize>) {
    let (group1, group2) = match unique_groups {
        [first, second, ..] => (*first, *second),
        _ => panic!(
            "get_group_indices requires at least two groups, got {}",
            unique_groups.len()
        ),
    };

    let mut group1_indices = Vec::new();
    let mut group2_indices = Vec::new();

    // Single pass over the labels. The two checks are independent (not `else if`)
    // so that identical labels still populate both vectors.
    for (i, &g) in group_ids.iter().enumerate() {
        if g == group1 {
            group1_indices.push(i);
        }
        if g == group2 {
            group2_indices.push(i);
        }
    }

    (group1_indices, group2_indices)
}
85
86/// Generic version of statistics accumulation that works with any SparseMatrixRef.
87/// 
88/// Assumes matrix is Genes (rows) x Cells (cols).
89pub(crate) fn accumulate_gene_statistics_two_groups_raw<T, N, I>(
90    matrix: SparseMatrixRef<T, N, I>,
91    group1_indices: &[usize],
92    group2_indices: &[usize],
93    _n_features: usize,
94) -> anyhow::Result<(Vec<T>, Vec<T>, Vec<T>, Vec<T>)>
95where
96    T: FloatOpsTS,
97    N: AsPrimitive<usize> + Send + Sync,
98    I: AsPrimitive<usize> + Send + Sync,
99{
100    let n_genes = matrix.n_rows;
101    let n_cells = matrix.n_cols;
102
103    // Create a mapping for group membership to avoid repeated linear searches
104    let mut cell_groups = vec![0u8; n_cells];
105    for &idx in group1_indices { cell_groups[idx] = 1; }
106    for &idx in group2_indices { cell_groups[idx] = 2; }
107
108    // Pre-allocate all accumulation vectors (one per gene/row)
109    let mut group1_sums = vec![T::zero(); n_genes];
110    let mut group1_sum_squares = vec![T::zero(); n_genes];
111    let mut group2_sums = vec![T::zero(); n_genes];
112    let mut group2_sum_squares = vec![T::zero(); n_genes];
113    
114    for row_idx in 0..n_genes {
115        let (cols, vals) = matrix.get_major(row_idx);
116        for (col_idx, &value) in cols.iter().zip(vals.iter()) {
117            let c_idx: usize = col_idx.as_();
118            match cell_groups[c_idx] {
119                1 => {
120                    group1_sums[row_idx] += value;
121                    group1_sum_squares[row_idx] += value * value;
122                }
123                2 => {
124                    group2_sums[row_idx] += value;
125                    group2_sum_squares[row_idx] += value * value;
126                }
127                _ => {}
128            }
129        }
130    }
131
132    Ok((group1_sums, group1_sum_squares, group2_sums, group2_sum_squares))
133}