1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
//! A simple trait that wraps SVD computation.

use crate::svd::SVD;
use ndarray::ArrayView2;
use ndarray_linalg::{SVDDCInto, UVTFlag};
use crate::types::{c32, c64, Result, Scalar, RustyCompressionError};

/// Computes a singular value decomposition for matrices of one scalar type.
///
/// Implementors name their element type through the associated type `A` and
/// provide `compute_svd` as an associated function (it takes no `self`).
pub(crate) trait ComputeSVD {
    /// Element type of the matrices handled by the implementation.
    type A: Scalar;

    /// Decomposes the two-dimensional array view `arr` and returns the factors as an [`SVD`].
    ///
    /// # Errors
    ///
    /// Failures are reported through the crate's `Result` type; which conditions
    /// count as a failure is decided by the implementation.
    fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>>;
}

/// Implements [`ComputeSVD`] for the scalar type `$scalar`.
///
/// The generated `compute_svd` delegates to `svddc_into` from `ndarray_linalg`,
/// requesting both factors with `UVTFlag::Some`, and packs the outcome into an
/// [`SVD`] value.
macro_rules! compute_svd_impl {
    ($scalar:ty) => {
        impl ComputeSVD for $scalar {
            type A = $scalar;

            /// Computes the SVD of `arr` and returns `u`, `s` and `vt` in an [`SVD`].
            ///
            /// # Errors
            ///
            /// Returns `RustyCompressionError::LinalgError` when the underlying
            /// `svddc_into` call fails.
            ///
            /// # Panics
            ///
            /// Panics if the backend reports success but omits `u` or `vt`,
            /// which would contradict the `UVTFlag::Some` request.
            fn compute_svd(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>> {
                // `svddc_into` takes ownership of its input, so the borrowed
                // view is copied into an owned array before the call.
                let (u, s, vt) = arr
                    .to_owned()
                    .svddc_into(UVTFlag::Some)
                    .map_err(RustyCompressionError::LinalgError)?;

                // Both factors were requested above, so a missing one is a
                // broken invariant rather than a recoverable error.
                let u = u.expect("svddc_into with UVTFlag::Some did not return U");
                let vt = vt.expect("svddc_into with UVTFlag::Some did not return V^T");

                Ok(SVD { u, s, vt })
            }
        }
    };
}

// Generate a `ComputeSVD` implementation for each supported scalar type:
// the primitive floats `f32`/`f64` and the `c32`/`c64` types from `crate::types`
// (presumably single- and double-precision complex numbers; confirm there).
compute_svd_impl!(f32);
compute_svd_impl!(f64);
compute_svd_impl!(c32);
compute_svd_impl!(c64);