rusty_compression/
compute_svd.rs1use crate::svd::SVD;
4use ndarray::ArrayView2;
5use ndarray_linalg::{SVDDCInto, UVTFlag};
6use crate::types::{c32, c64, Result, Scalar, RustyCompressionError};
7
8pub(crate) trait ComputeSVD {
9 type A: Scalar;
10
11 fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>>;
12}
13
/// Implements [`ComputeSVD`] for one concrete scalar type.
///
/// The generated `compute_svd` delegates to `ndarray_linalg::SVDDCInto::svddc_into`
/// (divide-and-conquer SVD) with `UVTFlag::Some`. Per the ndarray-linalg documentation,
/// that flag requests the reduced factors, i.e. the first `min(m, n)` columns of U and
/// rows of Vᵀ. Backend failures are surfaced as `RustyCompressionError::LinalgError`.
macro_rules! compute_svd_impl {
    ($scalar:ty) => {
        impl ComputeSVD for $scalar {
            type A = $scalar;

            fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>> {
                // `svddc_into` consumes its receiver, so an owned copy of the view is required.
                let (u, s, vt) = arr
                    .to_owned()
                    .svddc_into(UVTFlag::Some)
                    .map_err(RustyCompressionError::LinalgError)?;

                // Both U and Vt were explicitly requested via `UVTFlag::Some`, so a `None`
                // here would be a broken invariant in the backend, not a recoverable error.
                let u = u.expect("svddc_into(UVTFlag::Some) must return U");
                let vt = vt.expect("svddc_into(UVTFlag::Some) must return Vt");

                Ok(SVD { u, s, vt })
            }
        }
    };
}
31
32compute_svd_impl!(f32);
33compute_svd_impl!(f64);
34compute_svd_impl!(c32);
35compute_svd_impl!(c64);