1pub(crate) mod processing;
2pub(crate) mod statistics;
3pub(crate) mod utils;
4
5pub use processing::get_select_info_obs;
6pub use processing::get_select_info_vars;
7pub use processing::FlavorType;
8pub use processing::HVGParams;
9use std::collections::HashMap;
10
11use anndata::backend::ScalarType;
12use anndata::data::DynCsrMatrix;
13use anndata::data::{DynArray, DynCscMatrix, SelectInfoElem};
14use anndata::{data::Shape, ArrayData, HasShape};
15use anyhow::{anyhow, bail};
16use nalgebra_sparse::{CscMatrix, CsrMatrix};
17use ndarray::{Array2, ArrayD, Ix2};
18use num_traits::{NumCast, Zero};
19use single_utilities::traits::NumericOps;
20use utils::select_info_elem_to_indices;
21
/// Strategy for choosing which features (variables) take part in a downstream computation.
///
/// NOTE(review): the payload meanings below are inferred from the variant names only —
/// confirm against the consuming code.
#[derive(Debug, Clone, PartialEq)]
pub enum FeatureSelection {
    /// Select features flagged by the named column (presumably a boolean `var` column).
    HighlyVariableCol(String),
    /// Select highly variable features; the `usize` is presumably the number to keep.
    HighlyVariable(usize),
    /// Select features at random; the `usize` is presumably the number to draw.
    Randomized(usize),
    /// Select features by a variance cut-off given as an `f64` threshold.
    VarianceThreshold(f64),
    /// Perform no feature selection.
    None,
}
29
/// How a computation consumes its input.
///
/// The hand-written `Clone` impl is replaced by derives. The enum only holds a `usize`, so it
/// is also `Copy`, comparable and debuggable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputationMode {
    /// Process the input piecewise; the `usize` is presumably the chunk size — confirm with callers.
    Chunked(usize),
    /// Process the input in a single pass (inferred from the name — confirm with callers).
    Whole,
}
43
/// An optional `f32` quantity that is expressed either absolutely or relatively.
///
/// The hand-written `Clone` impl is replaced by derives. Both payloads are `f32`, so the type
/// is `Copy` and `PartialEq` (but not `Eq`).
///
/// NOTE(review): what a `Relative` value is relative to is decided by the consuming code —
/// confirm before relying on a specific interpretation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FlexValue {
    /// A quantity given in absolute terms.
    Absolute(f32),
    /// A quantity given relative to some reference chosen by the consumer.
    Relative(f32),
    /// No value was provided.
    None,
}
60
61impl FlexValue {
62 pub fn is_absolute(&self) -> bool {
63 match self {
64 Self::Absolute(_) => true,
65 Self::Relative(_) => false,
66 Self::None => false,
67 }
68 }
69
70 pub fn is_relative(&self) -> bool {
71 match self {
72 Self::Absolute(_) => false,
73 Self::Relative(_) => true,
74 Self::None => false,
75 }
76 }
77
78 pub fn is_none(&self) -> bool {
79 match self {
80 Self::Absolute(_) => false,
81 Self::Relative(_) => false,
82 Self::None => true,
83 }
84 }
85
86 pub fn is_some(&self) -> bool {
87 !self.is_none()
88 }
89}
90
/// Dispatches a generic function over the element type of a CSR matrix.
///
/// `match_dyn_csr_matrix!(csr, fun, a, b)` matches `csr` (an `anndata::data::DynCsrMatrix`,
/// or a reference to one) and calls `fun(matrix, a, b)` with the typed matrix held by the
/// matching variant. Extra arguments are optional, and a trailing comma is accepted.
///
/// Variant paths are fully qualified, so this exported macro also works at call sites that
/// have not imported `DynCsrMatrix` themselves.
///
/// # Panics
///
/// Panics for `I64`, `U64`, `Bool`, and `String` matrices, which are not supported.
#[macro_export]
macro_rules! match_dyn_csr_matrix {
    ($csr:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $csr {
            anndata::data::DynCsrMatrix::I8(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::I16(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::I32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::I64(_d) => {
                panic!("I64 CSR matrices are not supported for this operation")
            }
            anndata::data::DynCsrMatrix::U8(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::U16(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::U32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::U64(_d) => {
                panic!("U64 CSR matrices are not supported for this operation")
            }
            anndata::data::DynCsrMatrix::F32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::F64(d) => $fun(d $(, $arg)*),
            anndata::data::DynCsrMatrix::Bool(_) => {
                panic!("Boolean CSR matrices are not supported for this operation")
            }
            anndata::data::DynCsrMatrix::String(_) => {
                panic!("String CSR matrices are not supported for this operation")
            }
        }
    };
}
110
/// Dispatches a generic function over the element type of a CSC matrix.
///
/// `match_dyn_csc_matrix!(csc, fun, a, b)` matches `csc` (an `anndata::data::DynCscMatrix`,
/// or a reference to one) and calls `fun(matrix, a, b)` with the typed matrix held by the
/// matching variant. Extra arguments are optional, and a trailing comma is accepted.
///
/// Variant paths are fully qualified, so this exported macro also works at call sites that
/// have not imported `DynCscMatrix` themselves.
///
/// # Panics
///
/// Panics for `I64`, `U64`, `Bool`, and `String` matrices, which are not supported.
#[macro_export]
macro_rules! match_dyn_csc_matrix {
    ($csc:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $csc {
            anndata::data::DynCscMatrix::I8(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::I16(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::I32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::I64(_d) => {
                panic!("I64 CSC matrices are not supported for this operation")
            }
            anndata::data::DynCscMatrix::U8(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::U16(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::U32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::U64(_d) => {
                panic!("U64 CSC matrices are not supported for this operation")
            }
            anndata::data::DynCscMatrix::F32(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::F64(d) => $fun(d $(, $arg)*),
            anndata::data::DynCscMatrix::Bool(_) => {
                panic!("Boolean CSC matrices are not supported for this operation")
            }
            anndata::data::DynCscMatrix::String(_) => {
                panic!("String CSC matrices are not supported for this operation")
            }
        }
    };
}
130
/// Calls a method on the typed sparse matrix wrapped in an `anndata::ArrayData`.
///
/// `match_array_data_apply_function!(data, method, a, b)` matches `data` and, for CSR/CSC
/// matrices with a numeric element type, evaluates `matrix.method(a, b)`. The arguments are
/// optional (`match_array_data_apply_function!(data, method)` calls `matrix.method()`), and a
/// trailing comma is accepted.
///
/// # Errors
///
/// Expands to `anyhow::bail!`, i.e. an early `return Err(..)` from the *enclosing* function,
/// when the matrix has a non-numeric element type or `data` is not a CSR/CSC matrix. The
/// enclosing function must therefore return a type compatible with `anyhow::Result`.
/// `bail!` is path-qualified so callers need not import it.
#[macro_export]
macro_rules! match_array_data_apply_function {
    ($data:expr, $fun:ident $(, $arg:expr)* $(,)?) => {
        match $data {
            anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => match dyn_csr_matrix {
                anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun($($arg),*),
                _ => anyhow::bail!("This operation is only supported on numeric types!"),
            },
            anndata::ArrayData::CscMatrix(dyn_csc_matrix) => match dyn_csc_matrix {
                anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun($($arg),*),
                anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun($($arg),*),
                _ => anyhow::bail!("This operation is only supported on numeric types!"),
            },
            _ => anyhow::bail!("This operation is currently only supported for CSC and CSR matrices."),
        }
    };
}
207
/// Like `match_array_data_apply_function!`, but calls the method with explicit generic
/// arguments.
///
/// `match_array_data_apply_function_with_generics!(data, method, [T, U], a, b)` evaluates
/// `matrix.method::<T, U>(a, b)` on the typed CSR/CSC matrix inside `data`. The value
/// arguments are optional, and a trailing comma is accepted.
///
/// # Errors
///
/// Expands to `anyhow::bail!`, i.e. an early `return Err(..)` from the *enclosing* function,
/// when the matrix has a non-numeric element type or `data` is not a CSR/CSC matrix.
/// `bail!` is path-qualified so callers need not import it.
#[macro_export]
macro_rules! match_array_data_apply_function_with_generics {
    ($data:expr, $fun:ident, [$($types:ty),+] $(, $arg:expr)* $(,)?) => {
        match $data {
            anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => match dyn_csr_matrix {
                anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                _ => anyhow::bail!("This operation is only supported on numeric types!"),
            },
            anndata::ArrayData::CscMatrix(dyn_csc_matrix) => match dyn_csc_matrix {
                anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),*),
                _ => anyhow::bail!("This operation is only supported on numeric types!"),
            },
            _ => anyhow::bail!("This operation is currently only supported for CSC and CSR matrices."),
        }
    };
}
284
285pub fn convert_to_array_f64(arr_data: &ArrayData) -> anyhow::Result<Array2<f64>> {
286 let shape = arr_data.shape();
287 match arr_data {
288 ArrayData::Array(array) => convert_to_array_f64_array(array),
289 ArrayData::CsrMatrix(csr) => match_dyn_csr_matrix!(csr, convert_to_array_f64_csr, shape),
290 ArrayData::CsrNonCanonical(_csc) => todo!(),
291 ArrayData::CscMatrix(csc) => match_dyn_csc_matrix!(csc, convert_to_array_f64_csc, shape),
292 ArrayData::DataFrame(_) => todo!(),
293 }
294}
295
296fn convert_to_array_f64_array(darray: &DynArray) -> anyhow::Result<Array2<f64>> {
297 match darray {
298 DynArray::I8(arr) => convert_arrayd_to_array2_f64(arr),
299 DynArray::I16(arr) => convert_arrayd_to_array2_f64(arr),
300 DynArray::I32(arr) => convert_arrayd_to_array2_f64(arr),
301 DynArray::I64(_) => todo!(),
302 DynArray::U8(arr) => convert_arrayd_to_array2_f64(arr),
303 DynArray::U16(arr) => convert_arrayd_to_array2_f64(arr),
304 DynArray::U32(arr) => convert_arrayd_to_array2_f64(arr),
305 DynArray::U64(_) => todo!(),
306 DynArray::F32(arr) => convert_arrayd_to_array2_f64(arr),
307 DynArray::F64(array) => convert_arrayd_to_array2_f64(array),
308 DynArray::Bool(_) => todo!(),
309 DynArray::String(_) => todo!(),
310 }
311}
312
313fn convert_arrayd_to_array2_f64<T: NumericOps>(arrayd: &ArrayD<T>) -> anyhow::Result<Array2<f64>> {
314 let shape = arrayd.shape();
315
316 match shape.len() {
317 1 => Err(anyhow!("The ArrayD must have at least two dimensions!")),
318 2 => Ok(arrayd
319 .mapv(|x| NumCast::from(x).unwrap_or_else(f64::zero))
320 .into_dimensionality::<Ix2>()?),
321 _ => {
322 let rows = shape[0];
323 let cols = shape[1..].iter().product();
324 let flat_data: Vec<f64> = arrayd
325 .iter()
326 .map(|&x| NumCast::from(x).unwrap_or_else(f64::zero))
327 .collect();
328
329 let data = Array2::from_shape_vec((rows, cols), flat_data)?;
330 Ok(data)
331 }
332 }
333}
334
335fn convert_to_array_f64_csc<T: NumericOps>(
336 csc: &CscMatrix<T>,
337 shape: Shape,
338) -> anyhow::Result<Array2<f64>> {
339 let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
340 for (col, vec) in csc.col_iter().enumerate() {
341 for (&row, val) in vec.row_indices().iter().zip(csc.values()) {
342 dense[[row, col]] = NumCast::from(*val).unwrap();
343 }
344 }
345 Ok(dense)
346}
347
348fn convert_to_array_f64_csr<T: NumericOps>(
349 csr: &CsrMatrix<T>,
350 shape: Shape,
351) -> anyhow::Result<Array2<f64>> {
352 let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
353 for (row, vec) in csr.row_iter().enumerate() {
354 for (&col, val) in vec.col_indices().iter().zip(csr.values()) {
355 dense[[row, col]] = NumCast::from(*val).unwrap();
356 }
357 }
358 Ok(dense)
359}
360
361fn convert_to_array_f64_csr_selected<T: NumericOps>(
362 csr: &CsrMatrix<T>,
363 shape: Shape,
364 row_selection: &SelectInfoElem,
365 col_selection: &SelectInfoElem,
366) -> anyhow::Result<Array2<f64>> {
367 let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
368 let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
369 let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
370
371 let col_map: HashMap<usize, usize> = col_indices
373 .iter()
374 .enumerate()
375 .map(|(i, &col)| (col, i))
376 .collect();
377
378 for (out_row, &row) in row_indices.iter().enumerate() {
379 if row < csr.nrows() {
380 let row_start = csr.row_offsets()[row];
381 let row_end = csr.row_offsets()[row + 1];
382 for (&col, &value) in csr.col_indices()[row_start..row_end]
383 .iter()
384 .zip(csr.values()[row_start..row_end].iter())
385 {
386 if let Some(&out_col) = col_map.get(&col) {
387 dense[[out_row, out_col]] = NumCast::from(value)
388 .ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
389 }
390 }
391 }
392 }
393 Ok(dense)
394}
395
396fn convert_to_array_f64_csc_selected<T: NumericOps>(
397 csc: &CscMatrix<T>,
398 shape: Shape,
399 row_selection: &SelectInfoElem,
400 col_selection: &SelectInfoElem,
401) -> anyhow::Result<Array2<f64>> {
402 let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
403 let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
404 let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
405
406 let row_map: HashMap<usize, usize> = row_indices
408 .iter()
409 .enumerate()
410 .map(|(i, &row)| (row, i))
411 .collect();
412
413 for (out_col, &col) in col_indices.iter().enumerate() {
414 if col < csc.ncols() {
415 let col_start = csc.col_offsets()[col];
416 let col_end = csc.col_offsets()[col + 1];
417 for (&row, &value) in csc.row_indices()[col_start..col_end]
418 .iter()
419 .zip(csc.values()[col_start..col_end].iter())
420 {
421 if let Some(&out_row) = row_map.get(&row) {
422 dense[[out_row, out_col]] = NumCast::from(value)
423 .ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
424 }
425 }
426 }
427 }
428 Ok(dense)
429}
430
431pub fn convert_to_array_f64_selected(
432 data: &ArrayData,
433 shape: Shape,
434 row_selection: &SelectInfoElem,
435 col_selection: &SelectInfoElem,
436) -> anyhow::Result<Array2<f64>> {
437 match data {
438 ArrayData::CscMatrix(csc) => match_dyn_csc_matrix!(
439 csc,
440 convert_to_array_f64_csc_selected,
441 shape,
442 row_selection,
443 col_selection
444 ),
445 ArrayData::CsrMatrix(csr) => match_dyn_csr_matrix!(
446 csr,
447 convert_to_array_f64_csr_selected,
448 shape,
449 row_selection,
450 col_selection
451 ),
452 _ => anyhow::bail!("Unsupported data type for conversion to Array2<f64>"),
453 }
454}
455
456pub fn need_conversion_target_float_type(scalar_type: &ScalarType) -> anyhow::Result<bool> {
457 match scalar_type {
458 ScalarType::I8 => Ok(true),
459 ScalarType::I16 => Ok(true),
460 ScalarType::I32 => Ok(true),
461 ScalarType::I64 => Ok(true),
462 ScalarType::U8 => Ok(true),
463 ScalarType::U16 => Ok(true),
464 ScalarType::U32 => Ok(true),
465 ScalarType::U64 => Ok(true),
466 ScalarType::F32 => Ok(false),
467 ScalarType::F64 => Ok(false),
468 ScalarType::Bool => {
469 bail!("Cannot use a Scalar of type <Bool> in the normalization procedure.")
470 }
471 ScalarType::String => {
472 bail!("Cannot use a Scalar of type <String> in the normalization procedure.")
473 }
474 }
475}
476
/// Floating-point precision selector; defaults to [`Precision::Single`].
///
/// Also derives `PartialEq`/`Eq`/`Hash`, so callers can compare (`p == Precision::Double`) or
/// key on the value without resorting to `matches!`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Precision {
    /// Presumably 32-bit (`f32`) computation — confirm against consumers.
    #[default]
    Single,
    /// Presumably 64-bit (`f64`) computation — confirm against consumers.
    Double,
}