rusty_compression/
compute_svd.rs

//! A simple trait that wraps SVD computation.
2
3use crate::svd::SVD;
4use ndarray::ArrayView2;
5use ndarray_linalg::{SVDDCInto, UVTFlag};
6use crate::types::{c32, c64, Result, Scalar, RustyCompressionError};
7
8pub(crate) trait ComputeSVD {
9    type A: Scalar;
10
11    fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>>;
12}
13
/// Implements [`ComputeSVD`] for a concrete scalar type by delegating to
/// `ndarray_linalg`'s divide-and-conquer SVD (`SVDDCInto::svddc_into`).
macro_rules! compute_svd_impl {
    ($scalar:ty) => {
        impl ComputeSVD for $scalar {
            type A = $scalar;

            /// Computes the SVD of `arr` with `UVTFlag::Some`, which per the
            /// `ndarray_linalg` docs yields the reduced factors (first
            /// `min(m, n)` columns of U / rows of V^T).
            ///
            /// The view is copied into an owned array first because
            /// `svddc_into` consumes its input.
            ///
            /// # Errors
            ///
            /// Returns `RustyCompressionError::LinalgError` wrapping the error
            /// produced by `svddc_into` when the decomposition fails.
            ///
            /// # Panics
            ///
            /// Only if `svddc_into` violates its contract by omitting U or V^T
            /// even though they were requested.
            fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>> {
                let (u, s, vt) = arr
                    .to_owned()
                    .svddc_into(UVTFlag::Some)
                    .map_err(RustyCompressionError::LinalgError)?;

                // Both singular-vector matrices were requested via
                // `UVTFlag::Some`, so a `None` here is a backend bug, not a
                // recoverable condition.
                let u = u.expect("svddc_into with UVTFlag::Some must return U");
                let vt = vt.expect("svddc_into with UVTFlag::Some must return V^T");

                Ok(SVD { u, s, vt })
            }
        }
    };
}
31
32compute_svd_impl!(f32);
33compute_svd_impl!(f64);
34compute_svd_impl!(c32);
35compute_svd_impl!(c64);