1use crate::Df64;
7use crate::col_piv_qr::ColPivQR;
8use crate::numeric::CustomNumeric;
9use mdarray::DTensor;
10use nalgebra::{ComplexField, DMatrix, DVector, RealField};
11use num_traits::{One, ToPrimitive, Zero};
12
13#[derive(Debug, Clone)]
15pub struct SVDResult<T> {
16 pub u: DMatrix<T>,
18 pub s: DVector<T>,
20 pub v: DMatrix<T>,
22 pub rank: usize,
24}
25
/// Parameters controlling [`tsvd`].
#[derive(Debug, Clone)]
pub struct TSVDConfig<T> {
    /// Relative truncation tolerance: singular values that do not exceed
    /// `rtol` times the leading singular value are dropped. [`tsvd`] only
    /// accepts values strictly between 0 and 1.
    pub rtol: T,
}

impl<T> TSVDConfig<T> {
    /// Creates a configuration that uses `rtol` as the relative tolerance.
    pub fn new(rtol: T) -> Self {
        TSVDConfig { rtol }
    }
}
38
39#[derive(Debug, thiserror::Error)]
41pub enum TSVDError {
42 #[error("Matrix is empty")]
43 EmptyMatrix,
44 #[error("Invalid tolerance: {0}")]
45 InvalidTolerance(String),
46}
47
48#[inline]
55fn get_epsilon_for_svd<T: RealField + Copy>() -> T {
56 use std::any::TypeId;
57
58 if TypeId::of::<T>() == TypeId::of::<f64>() {
59 unsafe { std::ptr::read(&f64::EPSILON as *const f64 as *const T) }
61 } else if TypeId::of::<T>() == TypeId::of::<crate::Df64>() {
62 unsafe { std::ptr::read(&crate::Df64::EPSILON as *const crate::Df64 as *const T) }
64 } else {
65 T::from_f64(1e-15).unwrap_or(T::one() * T::from_f64(1e-15).unwrap_or(T::one()))
67 }
68}
69
70pub fn svd_decompose<T>(matrix: &DMatrix<T>, rtol: f64) -> SVDResult<T>
82where
83 T: ComplexField + RealField + Copy + nalgebra::RealField + ToPrimitive,
84{
85 let eps = get_epsilon_for_svd::<T>();
91
92 let svd = matrix
96 .clone()
97 .try_svd(true, true, eps, 0)
98 .expect("SVD computation failed");
99
100 let u_matrix = svd.u.unwrap();
102 let s_vector = svd.singular_values; let v_t_matrix = svd.v_t.unwrap();
104
105 let rank = calculate_rank_from_vector(&s_vector, rtol);
108
109 let u = DMatrix::from(u_matrix.columns(0, rank));
111 let s = DVector::from(s_vector.rows(0, rank));
112 let v = DMatrix::from(v_t_matrix.rows(0, rank).transpose());
113
114 SVDResult { u, s, v, rank }
115}
116
117fn calculate_rank_from_vector<T>(singular_values: &DVector<T>, rtol: f64) -> usize
130where
131 T: RealField + Copy + ToPrimitive,
132{
133 if singular_values.is_empty() {
134 return 0;
135 }
136
137 let max_sv = singular_values[0];
139 let threshold = max_sv * T::from_f64(rtol).unwrap_or(T::zero());
140
141 let mut rank = 0;
142 for &sv in singular_values.iter() {
143 if sv > threshold {
144 rank += 1;
145 } else {
146 break;
148 }
149 }
150
151 rank
152}
153
154fn calculate_rank_from_r<T: RealField>(r_matrix: &DMatrix<T>, rtol: T) -> usize
156where
157 T: ComplexField + RealField + Copy,
158{
159 let dim = r_matrix.nrows().min(r_matrix.ncols());
160 let mut rank = dim;
161
162 let mut max_diag_abs = Zero::zero();
164 for i in 0..dim {
165 let diag_abs = ComplexField::abs(r_matrix[(i, i)]);
166 if diag_abs > max_diag_abs {
167 max_diag_abs = diag_abs;
168 }
169 }
170
171 if max_diag_abs == Zero::zero() {
173 return 0;
174 }
175
176 for i in 0..dim {
178 let diag_abs = ComplexField::abs(r_matrix[(i, i)]);
179
180 if diag_abs < rtol * max_diag_abs {
182 rank = i;
183 break;
184 }
185 }
186
187 rank
188}
189
190pub fn tsvd<T>(matrix: &DMatrix<T>, config: TSVDConfig<T>) -> Result<SVDResult<T>, TSVDError>
204where
205 T: ComplexField
206 + RealField
207 + Copy
208 + nalgebra::RealField
209 + std::fmt::Debug
210 + ToPrimitive
211 + CustomNumeric,
212{
213 let (m, n) = matrix.shape();
214
215 if m == 0 || n == 0 {
216 return Err(TSVDError::EmptyMatrix);
217 }
218
219 if config.rtol <= Zero::zero() || config.rtol >= One::one() {
220 return Err(TSVDError::InvalidTolerance(format!(
221 "Tolerance must be in (0, 1), got {:?}",
222 config.rtol
223 )));
224 }
225
226 let qr_rtol = Some(config.rtol.clone().modulus());
229 let qr = ColPivQR::new_with_rtol(matrix.clone(), qr_rtol);
230 let q_matrix = qr.q();
231 let r_matrix = qr.r();
232 let permutation = qr.p();
233
234 let qr_rank = calculate_rank_from_r(
237 &r_matrix,
238 T::from_f64_unchecked(2.0) * get_epsilon_for_svd::<T>(),
239 );
240
241 if qr_rank == 0 {
242 return Ok(SVDResult {
244 u: DMatrix::zeros(m, 0),
245 s: DVector::zeros(0),
246 v: DMatrix::zeros(n, 0),
247 rank: 0,
248 });
249 }
250
251 let r_truncated: DMatrix<T> = r_matrix.rows(0, qr_rank).into();
253 let rtol_t = config.rtol;
255 let rtol_f64 = rtol_t.to_f64();
256 let svd_result = svd_decompose(&r_truncated, rtol_f64);
257
258 if svd_result.rank == 0 {
259 return Ok(SVDResult {
261 u: DMatrix::zeros(m, 0),
262 s: DVector::zeros(0),
263 v: DMatrix::zeros(n, 0),
264 rank: 0,
265 });
266 }
267
268 let q_truncated: DMatrix<T> = q_matrix.columns(0, qr_rank).into();
271 let u_full = &q_truncated * &svd_result.u;
272
273 let mut v_full = svd_result.v.clone();
278 permutation.inv_permute_rows(&mut v_full);
279
280 let s_full = svd_result.s.clone();
282
283 Ok(SVDResult {
284 u: u_full,
285 s: s_full,
286 v: v_full,
287 rank: svd_result.rank,
288 })
289}
290
291pub fn tsvd_f64(matrix: &DMatrix<f64>, rtol: f64) -> Result<SVDResult<f64>, TSVDError> {
293 tsvd(matrix, TSVDConfig::new(rtol))
294}
295
296pub fn tsvd_df64(matrix: &DMatrix<Df64>, rtol: Df64) -> Result<SVDResult<Df64>, TSVDError> {
298 tsvd(matrix, TSVDConfig::new(rtol))
299}
300
301pub fn tsvd_df64_from_f64(matrix: &DMatrix<f64>, rtol: f64) -> Result<SVDResult<Df64>, TSVDError> {
303 let matrix_df64 = DMatrix::from_fn(matrix.nrows(), matrix.ncols(), |i, j| {
304 Df64::from(matrix[(i, j)])
305 });
306 let rtol_df64 = Df64::from(rtol);
307 tsvd(&matrix_df64, TSVDConfig::new(rtol_df64))
308}
309
310pub fn compute_svd_dtensor<T: CustomNumeric + 'static>(
314 matrix: &DTensor<T, 2>,
315) -> (DTensor<T, 2>, Vec<T>, DTensor<T, 2>) {
316 use nalgebra::DMatrix;
317 use std::any::TypeId;
318
319 if TypeId::of::<T>() == TypeId::of::<f64>() {
321 let matrix_f64 = DMatrix::from_fn(matrix.shape().0, matrix.shape().1, |i, j| {
323 CustomNumeric::to_f64(matrix[[i, j]])
324 });
325
326 let rtol = 2.0 * f64::EPSILON;
328 let result = tsvd(&matrix_f64, TSVDConfig::new(rtol)).expect("TSVD computation failed");
329
330 let u = DTensor::<T, 2>::from_fn([result.u.nrows(), result.u.ncols()], |idx| {
332 let [i, j] = [idx[0], idx[1]];
333 T::from_f64_unchecked(result.u[(i, j)])
334 });
335
336 let s: Vec<T> = result.s.iter().map(|x| T::from_f64_unchecked(*x)).collect();
337
338 let v = DTensor::<T, 2>::from_fn([result.v.nrows(), result.v.ncols()], |idx| {
339 let [i, j] = [idx[0], idx[1]];
340 T::from_f64_unchecked(result.v[(i, j)])
341 });
342
343 (u, s, v)
344 } else if TypeId::of::<T>() == TypeId::of::<Df64>() {
345 let matrix_df64: DMatrix<Df64> =
348 DMatrix::from_fn(matrix.shape().0, matrix.shape().1, |i, j| {
349 unsafe { std::mem::transmute_copy(&matrix[[i, j]]) }
351 });
352
353 let rtol = Df64::from(2.0) * Df64::epsilon();
355 let result = tsvd_df64(&matrix_df64, rtol).expect("TSVD computation failed");
356
357 let u = DTensor::<T, 2>::from_fn([result.u.nrows(), result.u.ncols()], |idx| {
359 let [i, j] = [idx[0], idx[1]];
360 T::convert_from(result.u[(i, j)])
361 });
362
363 let s: Vec<T> = result.s.iter().map(|x| T::convert_from(*x)).collect();
364
365 let v = DTensor::<T, 2>::from_fn([result.v.nrows(), result.v.ncols()], |idx| {
366 let [i, j] = [idx[0], idx[1]];
367 T::convert_from(result.v[(i, j)])
368 });
369
370 (u, s, v)
371 } else {
372 panic!("SVD is only implemented for f64 and Df64");
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;
    use num_traits::cast::ToPrimitive;

    #[test]
    fn test_svd_identity_matrix() {
        let matrix = DMatrix::<f64>::identity(3, 3);
        let result = svd_decompose(&matrix, 1e-12);

        assert_eq!(result.rank, 3);
        assert_eq!(result.s.len(), 3);
        assert_eq!(result.u.nrows(), 3);
        assert_eq!(result.u.ncols(), 3);
        assert_eq!(result.v.nrows(), 3);
        assert_eq!(result.v.ncols(), 3);
    }

    #[test]
    fn test_tsvd_identity_matrix() {
        let matrix = DMatrix::<f64>::identity(3, 3);
        let result = tsvd_f64(&matrix, 1e-12).unwrap();

        assert_eq!(result.rank, 3);
        assert_eq!(result.s.len(), 3);
    }

    #[test]
    fn test_tsvd_rank_one() {
        // Outer product a·aᵀ with a = (1, 2, 3): the only non-zero singular
        // value is |a|² = 14.
        let matrix = DMatrix::<f64>::from_fn(3, 3, |i, j| (i + 1) as f64 * (j + 1) as f64);
        let result = tsvd_f64(&matrix, 1e-12).unwrap();

        assert_eq!(result.rank, 1);
        assert!(
            (result.s[0] - 14.0).abs() < 1e-10,
            "unexpected leading singular value {}",
            result.s[0]
        );
    }

    #[test]
    fn test_tsvd_empty_matrix() {
        // Zero rows, zero columns, or both must all be rejected.
        for &(rows, cols) in &[(0, 0), (0, 3), (3, 0)] {
            let matrix = DMatrix::<f64>::zeros(rows, cols);
            let result = tsvd_f64(&matrix, 1e-12);

            assert!(
                matches!(result, Err(TSVDError::EmptyMatrix)),
                "{}x{} matrix was not rejected as empty",
                rows,
                cols
            );
        }
    }

    #[test]
    fn test_tsvd_rejects_out_of_range_tolerance() {
        let matrix = DMatrix::<f64>::identity(2, 2);
        for &rtol in &[0.0, -1e-3, 1.0, 2.0] {
            assert!(
                matches!(tsvd_f64(&matrix, rtol), Err(TSVDError::InvalidTolerance(_))),
                "rtol = {} was not rejected",
                rtol
            );
        }
    }

    /// Builds the `n × n` Hilbert matrix `H[i][j] = 1 / (i + j + 1)`.
    fn create_hilbert_matrix_generic<T>(n: usize) -> DMatrix<T>
    where
        T: nalgebra::RealField + From<f64> + Copy + std::ops::Div<Output = T>,
    {
        DMatrix::from_fn(n, n, |i, j| T::one() / T::from((i + j + 1) as f64))
    }

    /// Recombines SVD factors as `u * diag(s) * vᵀ`.
    fn reconstruct_matrix_generic<T>(
        u: &DMatrix<T>,
        s: &nalgebra::DVector<T>,
        v: &DMatrix<T>,
    ) -> DMatrix<T>
    where
        T: nalgebra::RealField + Copy,
    {
        u * &DMatrix::from_diagonal(s) * &v.transpose()
    }

    /// Frobenius norm evaluated in `f64`.
    ///
    /// Panics if an entry cannot be converted to `f64`. Silently treating such
    /// entries as zero could make a reconstruction test pass vacuously.
    fn frobenius_norm_generic<T>(matrix: &DMatrix<T>) -> f64
    where
        T: nalgebra::RealField + Copy + ToPrimitive,
    {
        let mut sum = 0.0;
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                let val = matrix[(i, j)]
                    .to_f64()
                    .expect("matrix entry is not representable as f64");
                sum += val * val;
            }
        }
        sum.sqrt()
    }

    /// Checks that the TSVD of an `n × n` Hilbert matrix reconstructs it with
    /// relative Frobenius error at most `expected_max_error`.
    fn test_hilbert_reconstruction_generic<T>(n: usize, rtol: f64, expected_max_error: f64)
    where
        T: nalgebra::RealField
            + From<f64>
            + Copy
            + ToPrimitive
            + std::fmt::Debug
            + crate::numeric::CustomNumeric,
    {
        let h = create_hilbert_matrix_generic::<T>(n);

        let config = TSVDConfig::new(T::from(rtol));
        let result = tsvd(&h, config).unwrap();

        let h_reconstructed = reconstruct_matrix_generic(&result.u, &result.s, &result.v);

        let error_matrix = &h - &h_reconstructed;
        let error_norm = frobenius_norm_generic(&error_matrix);
        let relative_error = error_norm / frobenius_norm_generic(&h);

        assert!(
            relative_error <= expected_max_error,
            "Relative reconstruction error {} exceeds expected maximum {}",
            relative_error,
            expected_max_error
        );
    }

    #[test]
    fn test_hilbert_5x5_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(5, 1e-12, 1e-14);
    }

    #[test]
    fn test_hilbert_5x5_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(5, 1e-28, 1e-28);
    }

    #[test]
    fn test_hilbert_10x10_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(10, 1e-12, 1e-12);
    }

    #[test]
    fn test_hilbert_10x10_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(10, 1e-28, 1e-30);
    }

    #[test]
    fn test_hilbert_100x100_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(100, 1e-12, 1e-12);
    }

    #[test]
    fn test_hilbert_100x100_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(100, 1e-28, 1e-28);
    }
}