rusty_compression/
types.rs

//! This module collects the `RustyCompressionError` error type, the `Result` alias, and the various trait definitions.
2
3use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
4use ndarray_linalg::error::LinalgError;
5use ndarray_linalg::Norm;
6use ndarray_linalg::OperationNorm;
7use thiserror::Error;
8
9pub use ndarray_linalg::{c32, c64, Scalar};
10
11#[derive(Error, Debug)]
12pub enum RustyCompressionError {
13    #[error("Lapack Error")]
14    LinalgError(LinalgError),
15    #[error("Could not compress to desired tolerance")]
16    CompressionError,
17    #[error("Incompatible memory layout")]
18    LayoutError,
19    #[error("Pivoted QR failed")]
20    PivotedQRError,
21}
22
23pub type Result<T> = std::result::Result<T, RustyCompressionError>;
24
/// Application of `self` to an operand of type `Lhs`, producing [`Apply::Output`].
///
/// The type parameter `A` does not appear in any signature here; presumably it names the
/// scalar type so that one type can carry several impls — confirm against the implementors.
pub trait Apply<A, Lhs> {
    /// Result type of the application.
    type Output;

    /// Apply `self` to the operand `lhs` and return the result.
    fn dot(&self, lhs: &Lhs) -> <Self as Apply<A, Lhs>>::Output;
}
30
/// Companion to [`Apply`] with the same shape: `self` combined with an operand of type `Lhs`.
///
/// The difference from `Apply` is not visible from the signature; presumably the operand
/// acts from the other side (hence the "R" prefix) — confirm against the implementors.
pub trait RApply<A, Lhs> {
    /// Result type of the application.
    type Output;

    /// Combine `self` with the operand `lhs` and return the result.
    fn dot(&self, lhs: &Lhs) -> <Self as RApply<A, Lhs>>::Output;
}
36
37/// Matrix-Vector Product Trait
38///
39/// This trait defines an interface for operators that provide matrix-vector products.
40pub trait MatVec {
41    type A: Scalar;
42
43    // Return the number of rows of the operator.
44    fn nrows(&self) -> usize;
45
46    // Return the number of columns of the operator.
47    fn ncols(&self) -> usize;
48
49    // Return the matrix vector product of an operator with a vector.
50    fn matvec(&self, mat: ArrayView1<Self::A>) -> Array1<Self::A>;
51}
52
53/// Matrix-Matrix Product Trait
54///
55/// This trait defines the application of a linear operator $A$ to a matrix X representing multiple columns.
56/// If it is not implemented then a default implementation is used based on the `MatVec` trait applied to the
57/// individual columns of X.
58pub trait MatMat: MatVec {
59    // Return the matrix-matrix product of an operator with a matrix.
60    fn matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
61        let mut output = Array2::<Self::A>::zeros((self.nrows(), mat.ncols()));
62
63        for (index, col) in mat.axis_iter(Axis(1)).enumerate() {
64            output
65                .index_axis_mut(Axis(1), index)
66                .assign(&self.matvec(col));
67        }
68
69        output
70    }
71}
72
73/// Trait describing the product of the conjugate adjoint of an operator with a vector
74///
75/// In the case that the operator is a matrix then this simply describes the action $A^Hx$,
76/// where $x$ is a vector and $A^H$ the complex conjugate adjoint of $A$.
77pub trait ConjMatVec: MatVec {
78    // If `self` is a linear operator return the product of the conjugate of `self`
79    // with a vector.
80    fn conj_matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A>;
81}
82
83/// Trait describing the action of the conjugate adjoint of an operator with a matrix
84///
85/// In the case that the operator is a matrix then this simply describes the action $A^HX$,
86/// where $X$ is another matrix and $A^H$ the complex conjugate adjoint of $A$. If this trait
87/// is not implemented then a default implementation based on the `ConjMatVec` trait is used.
88pub trait ConjMatMat: MatMat + ConjMatVec {
89    // Return the product of the complex conjugate of `self` with a given matrix.
90    fn conj_matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
91        let mut output = Array2::<Self::A>::zeros((self.ncols(), mat.ncols()));
92
93        for (index, col) in mat.axis_iter(Axis(1)).enumerate() {
94            output
95                .index_axis_mut(Axis(1), index)
96                .assign(&self.conj_matvec(col));
97        }
98
99        output
100    }
101}
102
103impl<A, S> MatVec for ArrayBase<S, Ix2>
104where
105    A: Scalar,
106    S: Data<Elem = A>,
107{
108    type A = A;
109
110    fn nrows(&self) -> usize {
111        self.nrows()
112    }
113
114    fn ncols(&self) -> usize {
115        self.ncols()
116    }
117
118    fn matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A> {
119        self.dot(&vec)
120    }
121}
122
123impl<A, S> ConjMatVec for ArrayBase<S, Ix2>
124where
125    A: Scalar,
126    S: Data<Elem = A>,
127{
128    fn conj_matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A> {
129        vec.map(|item| item.conj())
130            .dot(self)
131            .map(|item| item.conj())
132    }
133}
134
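// Note: enabling this specialized impl as-is would overlap with the blanket
// `impl ... MatMat for T` below (conflicting implementations, E0119), because
// `ArrayBase<S, Ix2>` implements `MatVec`.
// impl<A, S> MatMat for ArrayBase<S, Ix2>
// where
//     A: Scalar,
//     S: Data<Elem = A>,
// {
//     fn matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
//         self.dot(&mat)
//     }
// }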
144
145impl<A: Scalar, T: MatVec<A=A>> MatMat for T {}
146impl<A: Scalar, T: ConjMatVec<A=A>> ConjMatMat for T {}
147
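// Note: enabling this specialized impl as-is would overlap with the blanket
// `impl ... ConjMatMat for T` above (conflicting implementations, E0119), because
// `ArrayBase<S, Ix2>` implements `ConjMatVec`.
// impl<A, S> ConjMatMat for ArrayBase<S, Ix2>
// where
//     A: Scalar,
//     S: Data<Elem = A>,
// {
//     fn conj_matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
//         mat.t()
//             .map(|item| item.conj())
//             .dot(self)
//             .t()
//             .map(|item| item.conj())
//     }
// }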
161
162pub trait RelDiff {
163    type A: Scalar;
164
165    /// Return the relative Frobenius norm difference of `first` and `second`.
166    fn rel_diff_fro(
167        first: ArrayView2<Self::A>,
168        second: ArrayView2<Self::A>,
169    ) -> <<Self as RelDiff>::A as Scalar>::Real;
170
171    /// Return the relative l2 vector norm difference of `first` and `second`.
172    fn rel_diff_l2(
173        first: ArrayView1<Self::A>,
174        second: ArrayView1<Self::A>,
175    ) -> <<Self as RelDiff>::A as Scalar>::Real;
176}
177
178macro_rules! rel_diff_impl {
179    ($scalar:ty) => {
180        impl RelDiff for $scalar {
181            type A = $scalar;
182            fn rel_diff_fro(
183                first: ArrayView2<Self::A>,
184                second: ArrayView2<Self::A>,
185            ) -> <<Self as RelDiff>::A as Scalar>::Real {
186                let diff = first.to_owned() - &second;
187                diff.opnorm_fro().unwrap() / second.opnorm_fro().unwrap()
188            }
189
190            fn rel_diff_l2(
191                first: ArrayView1<Self::A>,
192                second: ArrayView1<Self::A>,
193            ) -> <<Self as RelDiff>::A as Scalar>::Real {
194                let diff = first.to_owned() - &second;
195                diff.norm_l2() / second.norm_l2()
196            }
197        }
198    };
199}
200
201rel_diff_impl!(f32);
202rel_diff_impl!(f64);
203rel_diff_impl!(c32);
204rel_diff_impl!(c64);