1use crate::types::Apply;
14use crate::two_sided_interp_decomp::TwoSidedID;
15use crate::qr::{LQ, LQTraits};
16use crate::row_interp_decomp::RowIDTraits;
17use ndarray::{
18 Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
19};
20use crate::types::{c32, c64, Scalar, Result};
21
22pub struct ColumnID<A: Scalar> {
24 c: Array2<A>,
26 z: Array2<A>,
28 col_ind: Array1<usize>,
31}
32
33pub trait ColumnIDTraits {
45 type A: Scalar;
46
47 fn nrows(&self) -> usize {
49 self.get_c().nrows()
50 }
51
52 fn ncols(&self) -> usize {
54 self.get_z().ncols()
55 }
56
57 fn rank(&self) -> usize {
59 self.get_c().ncols()
60 }
61
62 fn to_mat(&self) -> Array2<Self::A> {
64 self.get_c().dot(&self.get_z())
65 }
66
67 fn get_c(&self) -> ArrayView2<Self::A>;
69
70 fn get_z(&self) -> ArrayView2<Self::A>;
72
73 fn get_col_ind(&self) -> ArrayView1<usize>;
75
76 fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A>;
77 fn get_z_mut(&mut self) -> ArrayViewMut2<Self::A>;
78 fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize>;
79
80 fn new(c: Array2<Self::A>, z: Array2<Self::A>, col_ind: Array1<usize>) -> Self;
83
84 fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>>;
86}
87
88macro_rules! impl_col_id {
89 ($scalar:ty) => {
90 impl ColumnIDTraits for ColumnID<$scalar> {
91 type A = $scalar;
92 fn get_c(&self) -> ArrayView2<Self::A> {
93 self.c.view()
94 }
95 fn get_z(&self) -> ArrayView2<Self::A> {
96 self.z.view()
97 }
98
99 fn get_col_ind(&self) -> ArrayView1<usize> {
100 self.col_ind.view()
101 }
102
103 fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A> {
104 self.c.view_mut()
105 }
106 fn get_z_mut(&mut self) -> ArrayViewMut2<Self::A> {
107 self.z.view_mut()
108 }
109 fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize> {
110 self.col_ind.view_mut()
111 }
112
113 fn new(c: Array2<Self::A>, z: Array2<Self::A>, col_ind: Array1<usize>) -> Self {
114 ColumnID::<$scalar> { c, z, col_ind }
115 }
116 fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>> {
117 let row_id = LQ::<$scalar>::compute_from(self.c.view())?.row_id()?;
118 Ok(TwoSidedID {
119 c: row_id.get_x().into_owned(),
120 x: row_id.get_r().into_owned(),
121 r: self.get_z().into_owned(),
122 row_ind: row_id.get_row_ind().into_owned(),
123 col_ind: self.col_ind.to_owned(),
124
125 })
126
127
128
129
130 }
131
132 }
133
134 impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for ColumnID<$scalar>
135 where
136 S: Data<Elem = $scalar>,
137 {
138 type Output = Array1<$scalar>;
139
140 fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
141 self.c.dot(&self.z.dot(rhs))
142 }
143 }
144
145 impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for ColumnID<$scalar>
146 where
147 S: Data<Elem = $scalar>,
148 {
149 type Output = Array2<$scalar>;
150
151 fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
152 self.c.dot(&self.z.dot(rhs))
153 }
154 }
155 };
156}
157
158impl_col_id!(f32);
159impl_col_id!(f64);
160impl_col_id!(c32);
161impl_col_id!(c64);
162
#[cfg(test)]
mod tests {

    use crate::col_interp_decomp::ColumnIDTraits;
    use crate::permutation::ApplyPermutationToMatrix;
    use crate::permutation::MatrixPermutationMode;
    use crate::qr::{QRTraits, QR};
    use crate::random_matrix::RandomMatrix;
    use crate::two_sided_interp_decomp::TwoSidedIDTraits;
    use crate::types::RelDiff;
    use crate::types::Scalar;
    use crate::CompressionType;

    // Generates one test per `name: scalar_type, (rows, cols), tolerance,` entry.
    //
    // Each test builds a random approximately low-rank matrix and compresses its
    // QR factorization adaptively to `tolerance`. It then converts the resulting
    // column ID into a two-sided ID and checks that:
    //   1. the two-sided ID reproduces the matrix to within `5 * tolerance`
    //      (relative Frobenius norm), and
    //   2. its `x` factor is square of size `rank` and matches the leading
    //      `rank x rank` block of the matrix after applying the row and column
    //      permutations of the two-sided ID.
    macro_rules! id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

                #[test]
                fn $name() {
                    let (m, n) = $dim;

                    let sigma_max = 1.0;
                    let sigma_min = 1E-10;
                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix(
                        (m, n),
                        sigma_max,
                        sigma_min,
                        &mut rng,
                    );

                    let qr = QR::<$scalar>::compute_from(mat.view())
                        .unwrap()
                        .compress(CompressionType::ADAPTIVE($tol))
                        .unwrap();
                    let rank = qr.rank();
                    let two_sided_id = qr.column_id().unwrap().two_sided_id().unwrap();

                    let rel_err =
                        <$scalar>::rel_diff_fro(two_sided_id.to_mat().view(), mat.view());
                    assert!(
                        rel_err < 5.0 * $tol,
                        "relative Frobenius error {} is too large",
                        rel_err
                    );

                    let mat_permuted = mat
                        .apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW)
                        .apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);

                    assert_eq!(two_sided_id.x.nrows(), two_sided_id.x.ncols());
                    assert_eq!(two_sided_id.x.nrows(), rank);

                    for row_index in 0..rank {
                        for col_index in 0..rank {
                            let expected = mat_permuted[[row_index, col_index]];
                            let actual = two_sided_id.x[[row_index, col_index]];
                            let abs_err = (actual - expected).abs();
                            // Relative bound. `<=` (not `<`) so that an exactly
                            // reproduced zero entry does not fail spuriously.
                            let bound = 10.0 * $tol * expected.abs();
                            assert!(
                                abs_err <= bound,
                                "entry ({}, {}): |x - a| = {} exceeds bound {}",
                                row_index,
                                col_index,
                                abs_err,
                                bound
                            );
                        }
                    }
                }

            )*

        };
    }

    id_compression_tests! {
        test_two_sided_from_col_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}