// single_rust/shared/mod.rs

1pub(crate) mod processing;
2pub(crate) mod statistics;
3pub(crate) mod utils;
4
5pub use processing::get_select_info_obs;
6pub use processing::get_select_info_vars;
7pub use processing::FlavorType;
8pub use processing::HVGParams;
9use std::collections::HashMap;
10
11use anndata::backend::ScalarType;
12use anndata::data::DynCsrMatrix;
13use anndata::data::{DynArray, DynCscMatrix, SelectInfoElem};
14use anndata::{data::Shape, ArrayData, HasShape};
15use anyhow::{anyhow, bail};
16use nalgebra_sparse::{CscMatrix, CsrMatrix};
17use ndarray::{Array2, ArrayD, Ix2};
18use num_traits::{NumCast, Zero};
19use single_utilities::traits::NumericOps;
20use utils::select_info_elem_to_indices;
21
/// Feature-selection strategy.
///
/// NOTE(review): the meaning of each variant is inferred from its name only; the code
/// that interprets this enum lives outside this module — confirm against it.
#[derive(Debug, Clone, PartialEq)]
pub enum FeatureSelection {
    /// Select features using the named column (presumably a highly-variable flag column).
    HighlyVariableCol(String),
    /// Select highly variable features (presumably the top-N by variability).
    HighlyVariable(usize),
    /// Select features at random (presumably the number of features to draw).
    Randomized(usize),
    /// Select features by a variance threshold (presumably keeps those above it).
    VarianceThreshold(f64),
    /// No feature selection (presumably keeps all features).
    None,
}
29
/// How a computation traverses its input.
///
/// NOTE(review): `Chunked(n)` presumably processes the data in pieces of size `n`
/// (confirm whether `n` counts rows or elements); `Whole` presumably processes
/// everything in one pass.
///
/// The variants are plain data, so `Clone`/`Copy` are derived rather than hand-written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputationMode {
    Chunked(usize),
    Whole,
}
43
// TODO: implement more flexibility here!
/// A numeric parameter that is given either as an absolute value, a relative value,
/// or not at all.
///
/// NOTE(review): what a `Relative` value is relative to is decided by the callers,
/// not in this module.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FlexValue {
    Absolute(f32),
    Relative(f32),
    None,
}

impl FlexValue {
    /// Returns `true` if this is [`FlexValue::Absolute`].
    pub fn is_absolute(&self) -> bool {
        matches!(self, Self::Absolute(_))
    }

    /// Returns `true` if this is [`FlexValue::Relative`].
    pub fn is_relative(&self) -> bool {
        matches!(self, Self::Relative(_))
    }

    /// Returns `true` if this is [`FlexValue::None`].
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Returns `true` if a value (absolute or relative) is present.
    pub fn is_some(&self) -> bool {
        !self.is_none()
    }
}
90
/// Dispatches a generic function over the numeric variants of a `DynCsrMatrix`.
///
/// Expands to a `match` on `$csr`. For each supported variant, the payload is bound and
/// passed as the first argument to `$fun`, followed by any extra arguments:
/// `$fun(payload, args...)`. Extra arguments are optional, and a trailing comma is
/// accepted. Each extra argument expression is repeated in every arm, but it is
/// evaluated at most once because only one arm runs.
///
/// `$fun` must be a plain identifier. `DynCsrMatrix` is named unqualified, so it must
/// be in scope at the call site (this file imports `anndata::data::DynCsrMatrix`).
///
/// # Panics
///
/// Panics when `$csr` is an `I64`, `U64`, `Bool` or `String` matrix.
#[macro_export]
macro_rules! match_dyn_csr_matrix {
    ($csr:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $csr {
            DynCsrMatrix::I8(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::I16(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::I32(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::I64(_) => {
                panic!("I64 CSR matrices are not supported for this operation")
            }
            DynCsrMatrix::U8(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::U16(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::U32(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::U64(_) => {
                panic!("U64 CSR matrices are not supported for this operation")
            }
            DynCsrMatrix::F32(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::F64(d) => $fun(d $(, $arg)*),
            DynCsrMatrix::Bool(_) => {
                panic!("Boolean CSR matrices are not supported for this operation")
            }
            DynCsrMatrix::String(_) => {
                panic!("String CSR matrices are not supported for this operation")
            }
        }
    };
}
110
/// Dispatches a generic function over the numeric variants of a `DynCscMatrix`.
///
/// Expands to a `match` on `$csc`. For each supported variant, the payload is bound and
/// passed as the first argument to `$fun`, followed by any extra arguments:
/// `$fun(payload, args...)`. Extra arguments are optional, and a trailing comma is
/// accepted. Each extra argument expression is repeated in every arm, but it is
/// evaluated at most once because only one arm runs.
///
/// `$fun` must be a plain identifier. `DynCscMatrix` is named unqualified, so it must
/// be in scope at the call site (this file imports `anndata::data::DynCscMatrix`).
///
/// # Panics
///
/// Panics when `$csc` is an `I64`, `U64`, `Bool` or `String` matrix.
#[macro_export]
macro_rules! match_dyn_csc_matrix {
    ($csc:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $csc {
            DynCscMatrix::I8(d) => $fun(d $(, $arg)*),
            DynCscMatrix::I16(d) => $fun(d $(, $arg)*),
            DynCscMatrix::I32(d) => $fun(d $(, $arg)*),
            DynCscMatrix::I64(_) => {
                panic!("I64 CSC matrices are not supported for this operation")
            }
            DynCscMatrix::U8(d) => $fun(d $(, $arg)*),
            DynCscMatrix::U16(d) => $fun(d $(, $arg)*),
            DynCscMatrix::U32(d) => $fun(d $(, $arg)*),
            DynCscMatrix::U64(_) => {
                panic!("U64 CSC matrices are not supported for this operation")
            }
            DynCscMatrix::F32(d) => $fun(d $(, $arg)*),
            DynCscMatrix::F64(d) => $fun(d $(, $arg)*),
            DynCscMatrix::Bool(_) => {
                panic!("Boolean CSC matrices are not supported for this operation")
            }
            DynCscMatrix::String(_) => {
                panic!("String CSC matrices are not supported for this operation")
            }
        }
    };
}
130
/// Calls the method `$fun` on the typed matrix stored inside a sparse `anndata::ArrayData`.
///
/// Expands to a nested `match`. For CSR or CSC data with a numeric element type
/// (`I8`, `I16`, `I32`, `I64`, `U8`, `U16`, `U32`, `U64`, `F32`, `F64`), the payload is
/// bound to `matrix` and the expansion evaluates to `matrix.$fun(args...)`. Extra
/// arguments are optional, and a trailing comma is accepted.
///
/// Any other element type or `ArrayData` variant makes the expansion invoke `bail!`
/// with an error message. `bail!` is resolved at the call site (this file imports
/// `anyhow::bail`), so the macro must be used in a function whose return type that
/// `bail!` can produce.
#[macro_export]
macro_rules! match_array_data_apply_function {
    ($data:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $data {
            anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => match dyn_csr_matrix {
                anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun($($arg),*),
                _ => bail!("This operation is only supported on numeric types!"),
            },
            anndata::ArrayData::CscMatrix(dyn_csc_matrix) => match dyn_csc_matrix {
                anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun($($arg),*),
                _ => bail!("This operation is only supported on numeric types!"),
            },
            _ => bail!("This operation is currently only supported for CSC and CSR matrices."),
        }
    };
}
207
/// Calls the generic method `$fun::<$types...>` on the typed matrix stored inside a
/// sparse `anndata::ArrayData`.
///
/// This is the turbofish variant of `match_array_data_apply_function!`. For CSR or CSC
/// data with a numeric element type (`I8`, `I16`, `I32`, `I64`, `U8`, `U16`, `U32`,
/// `U64`, `F32`, `F64`), the payload is bound to `matrix` and the expansion evaluates to
/// `matrix.$fun::<$types...>(args...)`. The bracketed type list must be non-empty. Extra
/// arguments are optional, and a trailing comma is accepted.
///
/// Any other element type or `ArrayData` variant makes the expansion invoke `bail!`
/// with an error message. `bail!` is resolved at the call site (this file imports
/// `anyhow::bail`), so the macro must be used in a function whose return type that
/// `bail!` can produce.
#[macro_export]
macro_rules! match_array_data_apply_function_with_generics {
    ($data:expr, $fun:ident, [$($types:ty),+] $(, $arg:expr)* $(,)?) => {
        match $data {
            anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => match dyn_csr_matrix {
                anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                _ => bail!("This operation is only supported on numeric types!"),
            },
            anndata::ArrayData::CscMatrix(dyn_csc_matrix) => match dyn_csc_matrix {
                anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                _ => bail!("This operation is only supported on numeric types!"),
            },
            _ => bail!("This operation is currently only supported for CSC and CSR matrices."),
        }
    };
}
284
285pub fn convert_to_array_f64(arr_data: &ArrayData) -> anyhow::Result<Array2<f64>> {
286    let shape = arr_data.shape();
287    match arr_data {
288        ArrayData::Array(array) => convert_to_array_f64_array(array),
289        ArrayData::CsrMatrix(csr) => match_dyn_csr_matrix!(csr, convert_to_array_f64_csr, shape),
290        ArrayData::CsrNonCanonical(_csc) => todo!(),
291        ArrayData::CscMatrix(csc) => match_dyn_csc_matrix!(csc, convert_to_array_f64_csc, shape),
292        ArrayData::DataFrame(_) => todo!(),
293    }
294}
295
296fn convert_to_array_f64_array(darray: &DynArray) -> anyhow::Result<Array2<f64>> {
297    match darray {
298        DynArray::I8(arr) => convert_arrayd_to_array2_f64(arr),
299        DynArray::I16(arr) => convert_arrayd_to_array2_f64(arr),
300        DynArray::I32(arr) => convert_arrayd_to_array2_f64(arr),
301        DynArray::I64(_) => todo!(),
302        DynArray::U8(arr) => convert_arrayd_to_array2_f64(arr),
303        DynArray::U16(arr) => convert_arrayd_to_array2_f64(arr),
304        DynArray::U32(arr) => convert_arrayd_to_array2_f64(arr),
305        DynArray::U64(_) => todo!(),
306        DynArray::F32(arr) => convert_arrayd_to_array2_f64(arr),
307        DynArray::F64(array) => convert_arrayd_to_array2_f64(array),
308        DynArray::Bool(_) => todo!(),
309        DynArray::String(_) => todo!(),
310    }
311}
312
313fn convert_arrayd_to_array2_f64<T: NumericOps>(arrayd: &ArrayD<T>) -> anyhow::Result<Array2<f64>> {
314    let shape = arrayd.shape();
315
316    match shape.len() {
317        1 => Err(anyhow!("The ArrayD must have at least two dimensions!")),
318        2 => Ok(arrayd
319            .mapv(|x| NumCast::from(x).unwrap_or_else(f64::zero))
320            .into_dimensionality::<Ix2>()?),
321        _ => {
322            let rows = shape[0];
323            let cols = shape[1..].iter().product();
324            let flat_data: Vec<f64> = arrayd
325                .iter()
326                .map(|&x| NumCast::from(x).unwrap_or_else(f64::zero))
327                .collect();
328
329            let data = Array2::from_shape_vec((rows, cols), flat_data)?;
330            Ok(data)
331        }
332    }
333}
334
335fn convert_to_array_f64_csc<T: NumericOps>(
336    csc: &CscMatrix<T>,
337    shape: Shape,
338) -> anyhow::Result<Array2<f64>> {
339    let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
340    for (col, vec) in csc.col_iter().enumerate() {
341        for (&row, val) in vec.row_indices().iter().zip(csc.values()) {
342            dense[[row, col]] = NumCast::from(*val).unwrap();
343        }
344    }
345    Ok(dense)
346}
347
348fn convert_to_array_f64_csr<T: NumericOps>(
349    csr: &CsrMatrix<T>,
350    shape: Shape,
351) -> anyhow::Result<Array2<f64>> {
352    let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
353    for (row, vec) in csr.row_iter().enumerate() {
354        for (&col, val) in vec.col_indices().iter().zip(csr.values()) {
355            dense[[row, col]] = NumCast::from(*val).unwrap();
356        }
357    }
358    Ok(dense)
359}
360
361fn convert_to_array_f64_csr_selected<T: NumericOps>(
362    csr: &CsrMatrix<T>,
363    shape: Shape,
364    row_selection: &SelectInfoElem,
365    col_selection: &SelectInfoElem,
366) -> anyhow::Result<Array2<f64>> {
367    let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
368    let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
369    let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
370
371    // Create a mapping from original column indices to output column indices
372    let col_map: HashMap<usize, usize> = col_indices
373        .iter()
374        .enumerate()
375        .map(|(i, &col)| (col, i))
376        .collect();
377
378    for (out_row, &row) in row_indices.iter().enumerate() {
379        if row < csr.nrows() {
380            let row_start = csr.row_offsets()[row];
381            let row_end = csr.row_offsets()[row + 1];
382            for (&col, &value) in csr.col_indices()[row_start..row_end]
383                .iter()
384                .zip(csr.values()[row_start..row_end].iter())
385            {
386                if let Some(&out_col) = col_map.get(&col) {
387                    dense[[out_row, out_col]] = NumCast::from(value)
388                        .ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
389                }
390            }
391        }
392    }
393    Ok(dense)
394}
395
396fn convert_to_array_f64_csc_selected<T: NumericOps>(
397    csc: &CscMatrix<T>,
398    shape: Shape,
399    row_selection: &SelectInfoElem,
400    col_selection: &SelectInfoElem,
401) -> anyhow::Result<Array2<f64>> {
402    let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
403    let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
404    let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
405
406    // Create a mapping from original row indices to output row indices
407    let row_map: HashMap<usize, usize> = row_indices
408        .iter()
409        .enumerate()
410        .map(|(i, &row)| (row, i))
411        .collect();
412
413    for (out_col, &col) in col_indices.iter().enumerate() {
414        if col < csc.ncols() {
415            let col_start = csc.col_offsets()[col];
416            let col_end = csc.col_offsets()[col + 1];
417            for (&row, &value) in csc.row_indices()[col_start..col_end]
418                .iter()
419                .zip(csc.values()[col_start..col_end].iter())
420            {
421                if let Some(&out_row) = row_map.get(&row) {
422                    dense[[out_row, out_col]] = NumCast::from(value)
423                        .ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
424                }
425            }
426        }
427    }
428    Ok(dense)
429}
430
431pub fn convert_to_array_f64_selected(
432    data: &ArrayData,
433    shape: Shape,
434    row_selection: &SelectInfoElem,
435    col_selection: &SelectInfoElem,
436) -> anyhow::Result<Array2<f64>> {
437    match data {
438        ArrayData::CscMatrix(csc) => match_dyn_csc_matrix!(
439            csc,
440            convert_to_array_f64_csc_selected,
441            shape,
442            row_selection,
443            col_selection
444        ),
445        ArrayData::CsrMatrix(csr) => match_dyn_csr_matrix!(
446            csr,
447            convert_to_array_f64_csr_selected,
448            shape,
449            row_selection,
450            col_selection
451        ),
452        _ => anyhow::bail!("Unsupported data type for conversion to Array2<f64>"),
453    }
454}
455
456pub fn need_conversion_target_float_type(scalar_type: &ScalarType) -> anyhow::Result<bool> {
457    match scalar_type {
458        ScalarType::I8 => Ok(true),
459        ScalarType::I16 => Ok(true),
460        ScalarType::I32 => Ok(true),
461        ScalarType::I64 => Ok(true),
462        ScalarType::U8 => Ok(true),
463        ScalarType::U16 => Ok(true),
464        ScalarType::U32 => Ok(true),
465        ScalarType::U64 => Ok(true),
466        ScalarType::F32 => Ok(false),
467        ScalarType::F64 => Ok(false),
468        ScalarType::Bool => {
469            bail!("Cannot use a Scalar of type <Bool> in the normalization procedure.")
470        }
471        ScalarType::String => {
472            bail!("Cannot use a Scalar of type <String> in the normalization procedure.")
473        }
474    }
475}
476
/// Floating-point precision selector; defaults to [`Precision::Single`].
///
/// NOTE(review): presumably `Single` maps to `f32` and `Double` to `f64` in the
/// consumers — confirm there.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    #[default]
    Single,
    Double,
}