1use crate::qr::{QRTraits, QR};
4use crate::CompressionType;
5use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis, Zip};
6use num::ToPrimitive;
7use crate::types::Result;
8use crate::types::RustyCompressionError;
9use crate::types::{c32,c64, ConjMatMat, Scalar};
10
11pub struct SVD<A: Scalar> {
14 pub u: Array2<A>,
16 pub s: Array1<A::Real>,
18 pub vt: Array2<A>,
20}
21
22pub trait SVDTraits {
24 type A: Scalar;
25
26 fn nrows(&self) -> usize {
28 self.get_u().nrows()
29 }
30
31 fn ncols(&self) -> usize {
33 self.get_vt().ncols()
34 }
35
36 fn rank(&self) -> usize {
38 self.get_u().ncols()
39 }
40
41 fn to_mat(&self) -> Array2<Self::A> {
43 let mut scaled_vt =
44 Array2::<Self::A>::zeros((self.get_vt().nrows(), self.get_vt().ncols()));
45 scaled_vt.assign(&self.get_vt());
46
47 Zip::from(scaled_vt.axis_iter_mut(Axis(0)))
48 .and(self.get_s().view())
49 .for_each(|mut row, &s_elem| {
50 row.map_inplace(|item| *item *= <Self::A as Scalar>::from_real(s_elem))
51 });
52
53 self.get_u().dot(&scaled_vt)
54 }
55
56 fn to_qr(self) -> Result<QR<Self::A>>;
58
59 fn compress(&self, compression_type: CompressionType) -> Result<SVD<Self::A>> {
61 match compression_type {
62 CompressionType::ADAPTIVE(tol) => self.compress_svd_tolerance(tol),
63 CompressionType::RANK(rank) => self.compress_svd_rank(rank),
64 }
65 }
66
67 fn compress_svd_rank(&self, mut max_rank: usize) -> Result<SVD<Self::A>> {
69 let (u, s, vt) = (self.get_u(), self.get_s(), self.get_vt());
70
71 if max_rank > s.len() {
72 max_rank = s.len()
73 }
74
75 let u = u.slice(s![.., 0..max_rank]);
76 let s = s.slice(s![0..max_rank]);
77 let vt = vt.slice(s![0..max_rank, ..]);
78
79 Ok(SVD {
80 u: u.into_owned(),
81 s: s.into_owned(),
82 vt: vt.into_owned(),
83 })
84 }
85
86 fn compress_svd_tolerance(&self, tol: f64) -> Result<SVD<Self::A>> {
88 assert!((tol < 1.0) && (0.0 <= tol), "Require 0 <= tol < 1.0");
89
90 let first_val = self.get_s()[0];
91
92 let pos = self
93 .get_s()
94 .iter()
95 .position(|&item| (item / first_val).to_f64().unwrap() < tol);
96
97 match pos {
98 Some(index) => self.compress_svd_rank(index),
99 None => Err(RustyCompressionError::CompressionError),
100 }
101 }
102
103 fn compute_from(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>>;
104
105 fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
111 range: ArrayView2<Self::A>,
112 op: &Op,
113 ) -> Result<SVD<Self::A>>;
114
115 fn get_u(&self) -> ArrayView2<Self::A>;
116 fn get_s(&self) -> ArrayView1<<Self::A as Scalar>::Real>;
117 fn get_vt(&self) -> ArrayView2<Self::A>;
118
119 fn get_u_mut(&mut self) -> ArrayViewMut2<Self::A>;
120 fn get_s_mut(&mut self) -> ArrayViewMut1<<Self::A as Scalar>::Real>;
121 fn get_vt_mut(&mut self) -> ArrayViewMut2<Self::A>;
122}
123
124macro_rules! svd_impl {
125 ($scalar:ty) => {
126 impl SVDTraits for SVD<$scalar> {
127 type A = $scalar;
128
129 fn get_u(&self) -> ArrayView2<Self::A> {
130 self.u.view()
131 }
132
133 fn get_s(&self) -> ArrayView1<<Self::A as Scalar>::Real> {
134 self.s.view()
135 }
136 fn get_vt(&self) -> ArrayView2<Self::A> {
137 self.vt.view()
138 }
139
140 fn get_u_mut(&mut self) -> ArrayViewMut2<Self::A> {
141 self.u.view_mut()
142 }
143 fn get_s_mut(&mut self) -> ArrayViewMut1<<Self::A as Scalar>::Real> {
144 self.s.view_mut()
145 }
146 fn get_vt_mut(&mut self) -> ArrayViewMut2<Self::A> {
147 self.vt.view_mut()
148 }
149
150 fn to_qr(self) -> Result<QR<Self::A>> {
151 let (u, s, mut vt) = (self.u, self.s, self.vt);
152
153 Zip::from(vt.axis_iter_mut(Axis(0)))
154 .and(s.view())
155 .for_each(|mut row, &s_elem| {
156 row.map_inplace(|item| *item *= <Self::A as Scalar>::from_real(s_elem))
157 });
158
159 let mut qr = QR::<$scalar>::compute_from(vt.view())?;
160 qr.q = u.dot(&qr.q);
161
162 Ok(qr)
163 }
164
165 fn compute_from(arr: ArrayView2<Self::A>) -> Result<SVD<Self::A>> {
166 use crate::compute_svd::ComputeSVD;
167
168 <$scalar>::compute_svd(arr)
169 }
170
171 fn compute_from_range_estimate<Op: ConjMatMat<A = Self::A>>(
172 range: ArrayView2<Self::A>,
173 op: &Op,
174 ) -> Result<SVD<Self::A>> {
175 let b = op.conj_matmat(range).t().map(|item| item.conj());
176 let svd = SVD::<$scalar>::compute_from(b.view())?;
177
178 Ok(SVD {
179 u: range.dot(&svd.u),
180 s: svd.get_s().into_owned(),
181 vt: svd.get_vt().into_owned(),
182 })
183 }
184 }
185 };
186}
187
188svd_impl!(f32);
189svd_impl!(f64);
190svd_impl!(c32);
191svd_impl!(c64);
192
#[cfg(test)]
mod tests {

    use super::*;
    use crate::types::RelDiff;
    use crate::random_matrix::RandomMatrix;
    use crate::CompressionType;
    use ndarray::Axis;
    use ndarray_linalg::OperationNorm;

    // Round trip SVD -> QR -> dense matrix; the result must agree with the
    // input matrix up to the given relative tolerance in the Frobenius norm.
    macro_rules! svd_to_qr_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (m, n) = $dim;

                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix(
                        (m, n),
                        1.0,
                        1E-10,
                        &mut rng,
                    );

                    let qr = SVD::<$scalar>::compute_from(mat.view())
                        .unwrap()
                        .to_qr()
                        .unwrap();
                    let actual = qr.to_mat();

                    // Check through the crate's own relative-difference helper.
                    let rel_diff = <$scalar>::rel_diff_fro(actual.view(), mat.view());
                    assert!(rel_diff < $tol);

                    // Check again with explicitly computed Frobenius norms.
                    let abs_err = (actual - mat.view()).opnorm_fro().unwrap();
                    let mat_norm = mat.opnorm_fro().unwrap();
                    assert!(abs_err / mat_norm < $tol);
                }
            )*
        };
    }

    // Compression to a fixed rank of 20; checks the factor dimensions and
    // the relative reconstruction error.
    macro_rules! svd_compression_by_rank_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (m, n) = $dim;
                    let rank: usize = 20;

                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix(
                        (m, n),
                        1.0,
                        1E-10,
                        &mut rng,
                    );

                    let full_svd = SVD::<$scalar>::compute_from(mat.view()).unwrap();
                    let svd = full_svd.compress(CompressionType::RANK(rank)).unwrap();

                    assert_eq!(svd.u.len_of(Axis(1)), rank);
                    assert_eq!(svd.vt.len_of(Axis(0)), rank);

                    let rel_diff = <$scalar>::rel_diff_fro(svd.to_mat().view(), mat.view());
                    assert!(rel_diff < $tol);
                }
            )*
        };
    }

    // Adaptive compression; the tolerance is used both for the compression
    // and as the bound on the relative reconstruction error.
    macro_rules! svd_compression_by_tol_tests {
        ($($name:ident: $scalar:ty, $dim:expr, $tol:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (m, n) = $dim;

                    let mut rng = rand::thread_rng();
                    let mat = <$scalar>::random_approximate_low_rank_matrix(
                        (m, n),
                        1.0,
                        1E-10,
                        &mut rng,
                    );

                    let full_svd = SVD::<$scalar>::compute_from(mat.view()).unwrap();
                    let svd = full_svd.compress(CompressionType::ADAPTIVE($tol)).unwrap();

                    let rel_diff = <$scalar>::rel_diff_fro(svd.to_mat().view(), mat.view());
                    assert!(rel_diff < $tol);
                }
            )*
        };
    }

    svd_to_qr_tests! {
        test_svd_to_qr_f32_thin: f32, (100, 50), 1E-5,
        test_svd_to_qr_c32_thin: ndarray_linalg::c32, (100, 50), 1E-5,
        test_svd_to_qr_f64_thin: f64, (100, 50), 1E-12,
        test_svd_to_qr_c64_thin: ndarray_linalg::c64, (100, 50), 1E-12,
        test_svd_to_qr_f32_thick: f32, (50, 100), 1E-5,
        test_svd_to_qr_c32_thick: ndarray_linalg::c32, (50, 100), 1E-5,
        test_svd_to_qr_f64_thick: f64, (50, 100), 1E-12,
        test_svd_to_qr_c64_thick: ndarray_linalg::c64, (50, 100), 1E-12,
    }

    svd_compression_by_rank_tests! {
        test_svd_compression_by_rank_f32_thin: f32, (100, 50), 1E-4,
        test_svd_compression_by_rank_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_svd_compression_by_rank_f64_thin: f64, (100, 50), 1E-4,
        test_svd_compression_by_rank_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_svd_compression_by_rank_f32_thick: f32, (50, 100), 1E-4,
        test_svd_compression_by_rank_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_svd_compression_by_rank_f64_thick: f64, (50, 100), 1E-4,
        test_svd_compression_by_rank_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }

    svd_compression_by_tol_tests! {
        test_svd_compression_by_tol_f32_thin: f32, (100, 50), 1E-4,
        test_svd_compression_by_tol_c32_thin: ndarray_linalg::c32, (100, 50), 1E-4,
        test_svd_compression_by_tol_f64_thin: f64, (100, 50), 1E-4,
        test_svd_compression_by_tol_c64_thin: ndarray_linalg::c64, (100, 50), 1E-4,
        test_svd_compression_by_tol_f32_thick: f32, (50, 100), 1E-4,
        test_svd_compression_by_tol_c32_thick: ndarray_linalg::c32, (50, 100), 1E-4,
        test_svd_compression_by_tol_f64_thick: f64, (50, 100), 1E-4,
        test_svd_compression_by_tol_c64_thick: ndarray_linalg::c64, (50, 100), 1E-4,
    }
}