rusty_compression/
row_interp_decomp.rs

//! Data structure for Row Interpolative Decomposition
//!
//! A row interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
//! defined as
//! $$
//! A\approx XR
//! $$
//! with $R\in\mathbb{C}^{k\times n}$ being a matrix whose rows form a subset of the rows
//! of $A$, and $X\in\mathbb{C}^{m\times k}$. The rows of $R$ are obtained from the corresponding
//! rows of $A$ via an index vector `row_ind`: if `row_ind[i] = j`, then the i-th row of $R$ is
//! identical to the j-th row of $A$.

15use crate::types::Apply;
16use crate::two_sided_interp_decomp::TwoSidedID;
17use crate::qr::{QR, QRTraits};
18use crate::col_interp_decomp::ColumnIDTraits;
19use ndarray::{
20    Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
21};
22use crate::types::{c32, c64, Scalar, Result};
23
24/// Store a Row Interpolative Decomposition
25pub struct RowID<A: Scalar> {
26    /// The X matrix of the row interpolative decomposition
27    x: Array2<A>,
28    /// The R matrix of the row interpolative decomposition
29    r: Array2<A>,
30    /// An index vector. If row_ind\[i\] = j then the ith row of
31    /// R is identical to the jth row of A.
32    row_ind: Array1<usize>,
33}
34
35/// Traits defining a row interpolative decomposition
36/// 
37/// A row interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
38// defined as
39// $$
40// A\approx XR
41// $$
42// with $R\in\mathbb{C}^{k\times n}$ being a matrix whose rows form a subset of the rows 
43// of $A$, and $X\in\mathbb{R}^{m\times k}$. The rows of $R$ are obtained from the corresponding rows of
44// $A$ via an index vector row_ind. If row_ind\[i\] = j then the ith row of $R$ is identical to the jth row
45// of $A$.
46pub trait RowIDTraits {
47    type A: Scalar;
48
49    /// Number of rows of the underlying operator
50    fn nrows(&self) -> usize {
51        self.get_x().nrows()
52    }
53
54    /// Number of columns of the underlying operator
55    fn ncols(&self) -> usize {
56        self.get_r().ncols()
57    }
58
59    /// Rank of the row interpolative decomposition
60    fn rank(&self) -> usize {
61        self.get_r().nrows()
62    }
63
64    /// Convert to a matrix
65    fn to_mat(&self) -> Array2<Self::A> {
66        self.get_x().dot(&self.get_r())
67    }
68
69    /// Return the X matrix
70    fn get_x(&self) -> ArrayView2<Self::A>;
71
72    /// Return the R matrix
73    fn get_r(&self) -> ArrayView2<Self::A>;
74
75    /// Return the index vector
76    fn get_row_ind(&self) -> ArrayView1<usize>;
77
78    fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A>;
79    fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A>;
80    fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize>;
81
82    /// Return a row interpolative decomposition from given component matrices $X$ and
83    /// $R$ and index array row_ind.
84    fn new(x: Array2<Self::A>, r: Array2<Self::A>, row_ind: Array1<usize>) -> Self;
85
86    /// Convert the row interpolative decomposition into a two sided interpolative decomposition
87    fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>>;
88
89}
90
91macro_rules! impl_row_id {
92    ($scalar:ty) => {
93        impl RowIDTraits for RowID<$scalar> {
94            type A = $scalar;
95            fn get_x(&self) -> ArrayView2<Self::A> {
96                self.x.view()
97            }
98            fn get_r(&self) -> ArrayView2<Self::A> {
99                self.r.view()
100            }
101
102            fn get_row_ind(&self) -> ArrayView1<usize> {
103                self.row_ind.view()
104            }
105
106            fn get_x_mut(&mut self) -> ArrayViewMut2<Self::A> {
107                self.x.view_mut()
108            }
109            fn get_r_mut(&mut self) -> ArrayViewMut2<Self::A> {
110                self.r.view_mut()
111            }
112            fn get_row_ind_mut(&mut self) -> ArrayViewMut1<usize> {
113                self.row_ind.view_mut()
114            }
115
116            fn new(x: Array2<Self::A>, r: Array2<Self::A>, row_ind: Array1<usize>) -> Self {
117                RowID::<$scalar> { x, r, row_ind }
118            }
119
120            fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>> {
121                let col_id = QR::<$scalar>::compute_from(self.r.view())?.column_id()?;
122                Ok(TwoSidedID {
123                    c: self.x.to_owned(),
124                    x: col_id.get_c().into_owned(),
125                    r: col_id.get_z().into_owned(),
126                    row_ind: self.row_ind.to_owned(),
127                    col_ind: col_id.get_col_ind().into_owned(),
128
129                })
130            }
131
132        }
133
134        impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for RowID<$scalar>
135        where
136            S: Data<Elem = $scalar>,
137        {
138            type Output = Array1<$scalar>;
139
140            fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
141                self.x.dot(&self.r.dot(rhs))
142            }
143        }
144
145        impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for RowID<$scalar>
146        where
147            S: Data<Elem = $scalar>,
148        {
149            type Output = Array2<$scalar>;
150
151            fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
152                self.x.dot(&self.r.dot(rhs))
153            }
154        }
155    };
156}
157
158impl_row_id!(f32);
159impl_row_id!(f64);
160impl_row_id!(c32);
161impl_row_id!(c64);
162
#[cfg(test)]
mod tests {

    use crate::permutation::ApplyPermutationToMatrix;
    use crate::permutation::MatrixPermutationMode;
    use crate::qr::{LQTraits, LQ};
    use crate::random_matrix::RandomMatrix;
    use crate::row_interp_decomp::RowIDTraits;
    use crate::two_sided_interp_decomp::TwoSidedIDTraits;
    use crate::types::RelDiff;
    use crate::types::Scalar;
    use crate::CompressionType;

    // Generates one test per entry. Each test compresses a random approximately
    // low-rank matrix with an adaptive LQ, converts the resulting row ID into a
    // two-sided ID, and checks the result against the input matrix.
    macro_rules! id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

        #[test]
        fn $name() {
            let (m, n) = $dim;

            let sigma_max = 1.0;
            let sigma_min = 1E-10;
            let mut rng = rand::thread_rng();
            let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

            let lq = LQ::<$scalar>::compute_from(mat.view())
                .unwrap()
                .compress(CompressionType::ADAPTIVE($tol))
                .unwrap();
            let rank = lq.rank();
            let two_sided_id = lq.row_id().unwrap().two_sided_id().unwrap();

            // The two-sided ID must reproduce the original matrix up to the tolerance.
            let rel_err = <$scalar>::rel_diff_fro(two_sided_id.to_mat().view(), mat.view());
            assert!(
                rel_err < 5.0 * $tol,
                "relative Frobenius error {} exceeds 5 * tol",
                rel_err
            );

            // Permute A by the selected row and column indices so that the skeleton
            // block appears in its top-left corner.
            let mat_permuted = mat
                .apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW)
                .apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);

            // The x factor of the two-sided ID must be square with dimension equal to the rank.
            assert_eq!(two_sided_id.x.nrows(), two_sided_id.x.ncols());
            assert_eq!(two_sided_id.x.nrows(), rank);

            // Each entry of the skeleton must match the corresponding entry of the
            // permuted input matrix.
            for ((row, col), &skeleton_entry) in two_sided_id.x.indexed_iter() {
                let expected = mat_permuted[[row, col]];
                assert!(
                    (skeleton_entry - expected).abs() < 10.0 * $tol * expected.abs(),
                    "skeleton entry ({}, {}) does not match the permuted input matrix",
                    row,
                    col
                );
            }
        }

            )*

        }
    }

    id_compression_tests! {
        test_two_sided_from_row_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f32_thick: f32, (50, 100), 5E-4,
        test_two_sided_from_row_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_two_sided_from_row_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}