rusty_compression/
row_interp_decomp.rs1use crate::types::Apply;
16use crate::two_sided_interp_decomp::TwoSidedID;
17use crate::qr::{QR, QRTraits};
18use crate::col_interp_decomp::ColumnIDTraits;
19use ndarray::{
20 Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
21};
22use crate::types::{c32, c64, Scalar, Result};
23
24pub struct RowID<A: Scalar> {
26 x: Array2<A>,
28 r: Array2<A>,
30 row_ind: Array1<usize>,
33}
34
35pub trait RowIDTraits {
47 type A: Scalar;
48
49 fn nrows(&self) -> usize {
51 self.get_x().nrows()
52 }
53
54 fn ncols(&self) -> usize {
56 self.get_r().ncols()
57 }
58
59 fn rank(&self) -> usize {
61 self.get_r().nrows()
62 }
63
64 fn to_mat(&self) -> Array2<Self::A> {
66 self.get_x().dot(&self.get_r())
67 }
68
69 fn get_x(&self) -> ArrayView2<Self::A>;
71
72 fn get_r(&self) -> ArrayView2<Self::A>;
74
75 fn get_row_ind(&self) -> ArrayView1<usize>;
77
78 fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A>;
79 fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
80 fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize>;
81
82 fn new(x: Array2<Self::A>, r: Array2<Self::A>, row_ind: Array1<usize>) -> Self;
85
86 fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>>;
88
89}
90
91macro_rules! impl_row_id {
92 ($scalar:ty) => {
93 impl RowIDTraits for RowID<$scalar> {
94 type A = $scalar;
95 fn get_x(&self) -> ArrayView2<Self::A> {
96 self.x.view()
97 }
98 fn get_r(&self) -> ArrayView2<Self::A> {
99 self.r.view()
100 }
101
102 fn get_row_ind(&self) -> ArrayView1<usize> {
103 self.row_ind.view()
104 }
105
106 fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A> {
107 self.x.view_mut()
108 }
109 fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
110 self.r.view_mut()
111 }
112 fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize> {
113 self.row_ind.view_mut()
114 }
115
116 fn new(x: Array2<Self::A>, r: Array2<Self::A>, row_ind: Array1<usize>) -> Self {
117 RowID::<$scalar> { x, r, row_ind }
118 }
119
120 fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>> {
121 let col_id = QR::<$scalar>::compute_from(self.r.view())?.column_id()?;
122 Ok(TwoSidedID {
123 c: self.x.to_owned(),
124 x: col_id.get_c().into_owned(),
125 r: col_id.get_z().into_owned(),
126 row_ind: self.row_ind.to_owned(),
127 col_ind: col_id.get_col_ind().into_owned(),
128
129 })
130 }
131
132 }
133
134 impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for RowID<$scalar>
135 where
136 S: Data<Elem = $scalar>,
137 {
138 type Output = Array1<$scalar>;
139
140 fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
141 self.x.dot(&self.r.dot(rhs))
142 }
143 }
144
145 impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for RowID<$scalar>
146 where
147 S: Data<Elem = $scalar>,
148 {
149 type Output = Array2<$scalar>;
150
151 fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
152 self.x.dot(&self.r.dot(rhs))
153 }
154 }
155 };
156}
157
158impl_row_id!(f32);
159impl_row_id!(f64);
160impl_row_id!(c32);
161impl_row_id!(c64);
162
#[cfg(test)]
mod tests {

    use crate::permutation::{ApplyPermutationToMatrix, MatrixPermutationMode};
    use crate::qr::{LQTraits, LQ};
    use crate::random_matrix::RandomMatrix;
    use crate::row_interp_decomp::RowIDTraits;
    use crate::two_sided_interp_decomp::TwoSidedIDTraits;
    use crate::types::{RelDiff, Scalar};
    use crate::CompressionType;

    // Each entry `name: scalar, (rows, cols), tol,` expands to one `#[test]` that:
    //   1. draws a random matrix via `random_approximate_low_rank_matrix`,
    //   2. compresses its LQ factorisation with `CompressionType::ADAPTIVE(tol)`,
    //   3. converts the resulting row ID into a two-sided ID,
    //   4. checks the relative reconstruction error (Frobenius norm), and checks that
    //      the square `rank x rank` middle factor matches the leading block of the
    //      input after permuting its rows and columns by the ID's indices.
    // `$tol` is spelled out at every use rather than bound once, because each use
    // may be inferred as a different float type.
    macro_rules! id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let (nrows, ncols) = $dim;
                    let (sigma_max, sigma_min) = (1.0, 1E-10);
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix(
                        (nrows, ncols),
                        sigma_max,
                        sigma_min,
                        &mut rng,
                    );

                    let lq = LQ::<$scalar>::compute_from(mat.view())
                        .unwrap()
                        .compress(CompressionType::ADAPTIVE($tol))
                        .unwrap();
                    let rank = lq.rank();
                    let row_id = lq.row_id().unwrap();
                    let two_sided_id = row_id.two_sided_id().unwrap();

                    let reconstruction = two_sided_id.to_mat();
                    assert!(<$scalar>::rel_diff_fro(reconstruction.view(), mat.view()) < 5.0 * $tol);

                    let mat_permuted = mat
                        .apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW)
                        .apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);

                    let skeleton = &two_sided_id.x;
                    assert!(skeleton.nrows() == skeleton.ncols());
                    assert!(skeleton.nrows() == rank);

                    let positions = (0..rank).flat_map(|r| (0..rank).map(move |c| (r, c)));
                    for (row, col) in positions {
                        let expected = mat_permuted[[row, col]];
                        assert!((skeleton[[row, col]] - expected).abs()
                            < 10.0 * $tol * expected.abs())
                    }
                }

            )*

        }
    }

    id_compression_tests! {
        test_two_sided_from_row_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f32_thick: f32, (50, 100), 5E-4,
        test_two_sided_from_row_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}