rusty_compression/
types.rs1use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
4use ndarray_linalg::error::LinalgError;
5use ndarray_linalg::Norm;
6use ndarray_linalg::OperationNorm;
7use thiserror::Error;
8
9pub use ndarray_linalg::{c32, c64, Scalar};
10
11#[derive(Error, Debug)]
12pub enum RustyCompressionError {
13 #[error("Lapack Error")]
14 LinalgError(LinalgError),
15 #[error("Could not compress to desired tolerance")]
16 CompressionError,
17 #[error("Incompatible memory layout")]
18 LayoutError,
19 #[error("Pivoted QR failed")]
20 PivotedQRError,
21}
22
23pub type Result<T> = std::result::Result<T, RustyCompressionError>;
24
/// Application of `self` to an operand of type `Lhs`, yielding `Self::Output`.
///
/// The type parameter `A` does not appear in the method signature (presumably it
/// names the scalar type — confirm with implementors); it still lets one type
/// implement this trait separately for each choice of `A`.
pub trait Apply<A, Lhs> {
    /// Type produced by [`Apply::dot`].
    type Output;

    /// Applies `self` to `operand`.
    fn dot(&self, operand: &Lhs) -> Self::Output;
}
30
/// Companion of `Apply` with the same shape: `self` is applied to an operand
/// of type `Lhs`, yielding `Self::Output`.
///
/// NOTE(review): the `R` prefix presumably marks a right-sided application —
/// confirm against implementors. As with `Apply`, the parameter `A` is not used
/// in the method signature.
pub trait RApply<A, Lhs> {
    /// Type produced by [`RApply::dot`].
    type Output;

    /// Applies `self` to `operand`.
    fn dot(&self, operand: &Lhs) -> Self::Output;
}
36
37pub trait MatVec {
41 type A: Scalar;
42
43 fn nrows(&self) -> usize;
45
46 fn ncols(&self) -> usize;
48
49 fn matvec(&self, mat: ArrayView1<Self::A>) -> Array1<Self::A>;
51}
52
53pub trait MatMat: MatVec {
59 fn matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
61 let mut output = Array2::<Self::A>::zeros((self.nrows(), mat.ncols()));
62
63 for (index, col) in mat.axis_iter(Axis(1)).enumerate() {
64 output
65 .index_axis_mut(Axis(1), index)
66 .assign(&self.matvec(col));
67 }
68
69 output
70 }
71}
72
73pub trait ConjMatVec: MatVec {
78 fn conj_matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A>;
81}
82
83pub trait ConjMatMat: MatMat + ConjMatVec {
89 fn conj_matmat(&self, mat: ArrayView2<Self::A>) -> Array2<Self::A> {
91 let mut output = Array2::<Self::A>::zeros((self.ncols(), mat.ncols()));
92
93 for (index, col) in mat.axis_iter(Axis(1)).enumerate() {
94 output
95 .index_axis_mut(Axis(1), index)
96 .assign(&self.conj_matvec(col));
97 }
98
99 output
100 }
101}
102
103impl<A, S> MatVec for ArrayBase<S, Ix2>
104where
105 A: Scalar,
106 S: Data<Elem = A>,
107{
108 type A = A;
109
110 fn nrows(&self) -> usize {
111 self.nrows()
112 }
113
114 fn ncols(&self) -> usize {
115 self.ncols()
116 }
117
118 fn matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A> {
119 self.dot(&vec)
120 }
121}
122
123impl<A, S> ConjMatVec for ArrayBase<S, Ix2>
124where
125 A: Scalar,
126 S: Data<Elem = A>,
127{
128 fn conj_matvec(&self, vec: ArrayView1<Self::A>) -> Array1<Self::A> {
129 vec.map(|item| item.conj())
130 .dot(self)
131 .map(|item| item.conj())
132 }
133}
134
135impl<A: Scalar, T: MatVec<A=A>> MatMat for T {}
146impl<A: Scalar, T: ConjMatVec<A=A>> ConjMatMat for T {}
147
148pub trait RelDiff {
163 type A: Scalar;
164
165 fn rel_diff_fro(
167 first: ArrayView2<Self::A>,
168 second: ArrayView2<Self::A>,
169 ) -> <<Self as RelDiff>::A as Scalar>::Real;
170
171 fn rel_diff_l2(
173 first: ArrayView1<Self::A>,
174 second: ArrayView1<Self::A>,
175 ) -> <<Self as RelDiff>::A as Scalar>::Real;
176}
177
178macro_rules! rel_diff_impl {
179 ($scalar:ty) => {
180 impl RelDiff for $scalar {
181 type A = $scalar;
182 fn rel_diff_fro(
183 first: ArrayView2<Self::A>,
184 second: ArrayView2<Self::A>,
185 ) -> <<Self as RelDiff>::A as Scalar>::Real {
186 let diff = first.to_owned() - &second;
187 diff.opnorm_fro().unwrap() / second.opnorm_fro().unwrap()
188 }
189
190 fn rel_diff_l2(
191 first: ArrayView1<Self::A>,
192 second: ArrayView1<Self::A>,
193 ) -> <<Self as RelDiff>::A as Scalar>::Real {
194 let diff = first.to_owned() - &second;
195 diff.norm_l2() / second.norm_l2()
196 }
197 }
198 };
199}
200
201rel_diff_impl!(f32);
202rel_diff_impl!(f64);
203rel_diff_impl!(c32);
204rel_diff_impl!(c64);