rusty_compression/
col_interp_decomp.rs

//! Data structures for the column interpolative decomposition.
//!
//! A column interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
//! defined as
//! $$
//! A\approx CZ
//! $$
//! with $C\in\mathbb{C}^{m\times k}$ being a matrix whose columns form a subset of the columns
//! of $A$, and $Z\in\mathbb{C}^{k\times n}$. The columns of $C$ are obtained from the corresponding columns of
//! $A$ via an index vector col_ind. If col_ind\[i\] = j then the ith column of $C$ is identical to the jth column
//! of $A$.
12
13use crate::types::Apply;
14use crate::two_sided_interp_decomp::TwoSidedID;
15use crate::qr::{LQ, LQTraits};
16use crate::row_interp_decomp::RowIDTraits;
17use ndarray::{
18    Array1, Array2, ArrayBase, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Data, Ix1, Ix2,
19};
20use crate::types::{c32, c64, Scalar, Result};
21
22/// Store a Column Interpolative Decomposition
23pub struct ColumnID<A: Scalar> {
24    /// The C matrix of the column interpolative decomposition
25    c: Array2<A>,
26    /// The Z matrix of the column interpolative decomposition
27    z: Array2<A>,
28    /// An index vector. If col_ind\[i\] = j then the ith column of
29    /// C is identical to the jth column of A.
30    col_ind: Array1<usize>,
31}
32
33/// Traits defining a column interpolative decomposition
34/// 
35// A column interpolative decomposition of a matrix $A\in\mathbb{C}^{m\times n}$ is
36// defined as
37// $$
38// A\approx CZ
39// $$
40// with $C\in\mathbb{C}^{m\times k}$ being a matrix whose columns form a subset of the columns 
41// of $A$, and $Z\in\mathbb{R}^{k\times m}$. The columns of $C$ are obtained from the corresponding columns of
42// $A$ via an index vector col_ind. If col_ind\[i\] = j then the ith column of $C$ is identical to the jth column
43// of $A$.
44pub trait ColumnIDTraits {
45    type A: Scalar;
46
47    /// Number of rows of the underlying operator
48    fn nrows(&self) -> usize {
49        self.get_c().nrows()
50    }
51
52    /// Number of columns of the underlying operator
53    fn ncols(&self) -> usize {
54        self.get_z().ncols()
55    }
56
57    /// Rank of the column interpolative decomposition
58    fn rank(&self) -> usize {
59        self.get_c().ncols()
60    }
61
62    /// Convert to a matrix
63    fn to_mat(&self) -> Array2<Self::A> {
64        self.get_c().dot(&self.get_z())
65    }
66
67    /// Return the C matrix
68    fn get_c(&self) -> ArrayView2<Self::A>;
69
70    /// Return the Z matrix
71    fn get_z(&self) -> ArrayView2<Self::A>;
72
73    /// Return the index vector
74    fn get_col_ind(&self) -> ArrayView1<usize>;
75
76    fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A>;
77    fn get_z_mut(&mut self) -> ArrayViewMut2<Self::A>;
78    fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize>;
79
80    /// Return a column interpolative decomposition from given component matrices $C$ and
81    /// $Z$ and index array col_ind
82    fn new(c: Array2<Self::A>, z: Array2<Self::A>, col_ind: Array1<usize>) -> Self;
83
84    /// Convert the column interpolative decomposition into a two sided interpolative decomposition
85    fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>>;
86}
87
/// Implements [`ColumnIDTraits`] and the matrix-vector / matrix-matrix [`Apply`]
/// products for `ColumnID<$scalar>`.
macro_rules! impl_col_id {
    ($scalar:ty) => {
        impl ColumnIDTraits for ColumnID<$scalar> {
            type A = $scalar;

            fn get_c(&self) -> ArrayView2<Self::A> {
                self.c.view()
            }

            fn get_z(&self) -> ArrayView2<Self::A> {
                self.z.view()
            }

            fn get_col_ind(&self) -> ArrayView1<usize> {
                self.col_ind.view()
            }

            fn get_c_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.c.view_mut()
            }

            fn get_z_mut(&mut self) -> ArrayViewMut2<Self::A> {
                self.z.view_mut()
            }

            fn get_col_ind_mut(&mut self) -> ArrayViewMut1<usize> {
                self.col_ind.view_mut()
            }

            fn new(c: Array2<Self::A>, z: Array2<Self::A>, col_ind: Array1<usize>) -> Self {
                Self { c, z, col_ind }
            }

            fn two_sided_id(&self) -> Result<TwoSidedID<Self::A>> {
                // Compute a row ID of C via an LQ factorization. Its X factor becomes
                // the left factor and its R factor the middle factor of the two sided ID;
                // Z from this column ID becomes the right factor.
                let lq = LQ::<$scalar>::compute_from(self.get_c())?;
                let row_id = lq.row_id()?;

                let left = row_id.get_x().into_owned();
                let middle = row_id.get_r().into_owned();
                let row_ind = row_id.get_row_ind().into_owned();

                Ok(TwoSidedID {
                    c: left,
                    x: middle,
                    r: self.get_z().into_owned(),
                    row_ind,
                    col_ind: self.get_col_ind().into_owned(),
                })
            }
        }

        impl<S> Apply<$scalar, ArrayBase<S, Ix1>> for ColumnID<$scalar>
        where
            S: Data<Elem = $scalar>,
        {
            type Output = Array1<$scalar>;

            fn dot(&self, rhs: &ArrayBase<S, Ix1>) -> Self::Output {
                // Apply Z first, then C.
                let projected = self.z.dot(rhs);
                self.c.dot(&projected)
            }
        }

        impl<S> Apply<$scalar, ArrayBase<S, Ix2>> for ColumnID<$scalar>
        where
            S: Data<Elem = $scalar>,
        {
            type Output = Array2<$scalar>;

            fn dot(&self, rhs: &ArrayBase<S, Ix2>) -> Self::Output {
                // Apply Z first, then C.
                let projected = self.z.dot(rhs);
                self.c.dot(&projected)
            }
        }
    };
}
157
158impl_col_id!(f32);
159impl_col_id!(f64);
160impl_col_id!(c32);
161impl_col_id!(c64);
162
#[cfg(test)]
mod tests {

    use crate::col_interp_decomp::ColumnIDTraits;
    use crate::permutation::ApplyPermutationToMatrix;
    use crate::permutation::MatrixPermutationMode;
    use crate::qr::{QRTraits, QR};
    use crate::random_matrix::RandomMatrix;
    use crate::two_sided_interp_decomp::TwoSidedIDTraits;
    use crate::types::RelDiff;
    use crate::types::Scalar;
    use crate::CompressionType;

    /// Generates one test per entry: build a column ID from an adaptively compressed QR,
    /// convert it to a two sided ID and check both the global approximation error and
    /// that the middle factor reproduces the corresponding entries of the matrix.
    macro_rules! id_compression_tests {

        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {

            $(

        #[test]
        fn $name() {
            let m = $dim.0;
            let n = $dim.1;

            let sigma_max = 1.0;
            let sigma_min = 1E-10;
            let mut rng = rand::thread_rng();
            let mat = <$scalar>::random_approximate_low_rank_matrix((m, n), sigma_max, sigma_min, &mut rng);

            let qr = QR::<$scalar>::compute_from(mat.view())
                .unwrap()
                .compress(CompressionType::ADAPTIVE($tol))
                .unwrap();
            let rank = qr.rank();
            let two_sided_id = qr.column_id().unwrap().two_sided_id().unwrap();

            // The two sided ID must reproduce the original matrix up to the tolerance.
            assert!(<$scalar>::rel_diff_fro(two_sided_id.to_mat().view(), mat.view()) < 5.0 * $tol);

            // The x matrix of the two sided ID must be square with dimension equal to the rank.
            assert_eq!(two_sided_id.x.nrows(), two_sided_id.x.ncols());
            assert_eq!(two_sided_id.x.nrows(), rank);

            // Permute rows and columns of the original matrix with the ID index vectors and
            // compare the leading rank x rank block entrywise with x.
            let mat_permuted = mat
                .apply_permutation(two_sided_id.row_ind.view(), MatrixPermutationMode::ROW)
                .apply_permutation(two_sided_id.col_ind.view(), MatrixPermutationMode::COL);

            for row_index in 0..rank {
                for col_index in 0..rank {
                    let expected = mat_permuted[[row_index, col_index]];
                    let actual = two_sided_id.x[[row_index, col_index]];
                    // `<=` (not `<`) so that an exactly reproduced zero entry does not fail.
                    assert!(
                        (actual - expected).abs() <= 10.0 * $tol * expected.abs(),
                        "entry ({}, {}) of x deviates from the matrix, relative error {}",
                        row_index,
                        col_index,
                        (actual - expected).abs() / expected.abs()
                    );
                }
            }
        }

            )*

        }
    }

    id_compression_tests! {
        test_two_sided_from_col_id_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_two_sided_from_col_id_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}