1pub mod preconditioning;
24
25pub mod sketching;
27
28pub mod rand_nla;
30
31use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
32use scirs2_core::numeric::{Float, NumAssign};
33use scirs2_core::random::prelude::*;
34use scirs2_core::random::{Distribution, Normal};
35use std::fmt::Debug;
36use std::iter::Sum;
37
38use crate::decomposition::{qr, svd};
39use crate::error::{LinalgError, LinalgResult};
40
/// Configuration for the randomized low-rank routines in this module
/// (randomized SVD, randomized PCA).
#[derive(Debug, Clone)]
pub struct RandomizedConfig {
    /// Target rank of the approximation (number of components to recover).
    pub rank: usize,
    /// Extra sketch columns drawn beyond `rank` to improve accuracy
    /// (defaults to 10 in [`RandomizedConfig::new`]).
    pub oversampling: usize,
    /// Number of power (subspace) iterations used to sharpen the sketch
    /// when the spectrum decays slowly (defaults to 2).
    pub power_iterations: usize,
    /// Optional RNG seed. NOTE(review): this value is stored but nothing in
    /// this module reads it — confirm whether downstream code consumes it.
    pub seed: Option<u64>,
}
53
54impl RandomizedConfig {
55 pub fn new(rank: usize) -> Self {
57 Self {
58 rank,
59 oversampling: 10,
60 power_iterations: 2,
61 seed: None,
62 }
63 }
64
65 pub fn with_oversampling(mut self, oversampling: usize) -> Self {
67 self.oversampling = oversampling;
68 self
69 }
70
71 pub fn with_power_iterations(mut self, power_iterations: usize) -> Self {
73 self.power_iterations = power_iterations;
74 self
75 }
76
77 pub fn with_seed(mut self, seed: u64) -> Self {
79 self.seed = Some(seed);
80 self
81 }
82}
83
/// Result of a randomized PCA fit (see `randomized_pca`).
#[derive(Debug, Clone)]
pub struct RandomizedPcaResult<F> {
    /// Principal axes, one row per component (`n_components x n_features`).
    pub components: Array2<F>,
    /// Variance captured by each component: `sigma_i^2 / (n_samples - 1)`.
    pub explained_variance: Array1<F>,
    /// Fraction of the total variance captured by each component.
    pub explained_variance_ratio: Array1<F>,
    /// Singular values of the centered data retained by the truncated SVD.
    pub singular_values: Array1<F>,
    /// Per-feature mean of the training data, used for centering.
    pub mean: Array1<F>,
}
98
99fn gaussian_random_matrix<F>(rows: usize, cols: usize) -> LinalgResult<Array2<F>>
104where
105 F: Float + NumAssign + 'static,
106{
107 let mut rng = scirs2_core::random::rng();
108 let normal = Normal::new(0.0, 1.0).map_err(|e| {
109 LinalgError::ComputationError(format!("Failed to create normal distribution: {e}"))
110 })?;
111
112 let mut omega = Array2::zeros((rows, cols));
113 for i in 0..rows {
114 for j in 0..cols {
115 omega[[i, j]] = F::from(normal.sample(&mut rng)).unwrap_or(F::zero());
116 }
117 }
118 Ok(omega)
119}
120
121fn thin_orthogonalize<F>(y: &ArrayView2<F>, max_cols: usize) -> LinalgResult<Array2<F>>
126where
127 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
128{
129 let (m, n_cols) = y.dim();
130 let target = max_cols.min(n_cols).min(m);
131
132 if m >= n_cols {
133 let (q_full, _) = qr(y, None)?;
135 let actual = target.min(q_full.ncols());
137 Ok(q_full.slice(s![.., ..actual]).to_owned())
138 } else {
139 let (u, _, _) = svd(y, false, None)?;
141 let actual = target.min(u.ncols());
142 Ok(u.slice(s![.., ..actual]).to_owned())
143 }
144}
145
146pub fn randomized_range_finder<F>(
171 a: &ArrayView2<F>,
172 rank: usize,
173 oversampling: Option<usize>,
174 power_iterations: Option<usize>,
175) -> LinalgResult<Array2<F>>
176where
177 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
178{
179 let (m, n) = a.dim();
180 let p = oversampling.unwrap_or(10);
181 let q_iters = power_iterations.unwrap_or(0);
182 let l = (rank + p).min(m).min(n);
184
185 if rank == 0 {
186 return Err(LinalgError::InvalidInput(
187 "Target rank must be greater than 0".to_string(),
188 ));
189 }
190 if rank > m.min(n) {
191 return Err(LinalgError::InvalidInput(format!(
192 "Target rank ({rank}) exceeds min(m, n) = {}",
193 m.min(n)
194 )));
195 }
196
197 let omega = gaussian_random_matrix::<F>(n, l)?;
199
200 let mut y = a.dot(&omega);
202
203 for _ in 0..q_iters {
206 let q_y = thin_orthogonalize(&y.view(), l)?;
208
209 let z = a.t().dot(&q_y);
211
212 let q_z = thin_orthogonalize(&z.view(), l)?;
214
215 y = a.dot(&q_z);
217 }
218
219 let q_trunc = thin_orthogonalize(&y.view(), l)?;
221
222 Ok(q_trunc)
223}
224
225pub fn adaptive_range_finder<F>(
245 a: &ArrayView2<F>,
246 tolerance: F,
247 max_rank: Option<usize>,
248 block_size: Option<usize>,
249) -> LinalgResult<Array2<F>>
250where
251 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
252{
253 let (m, n) = a.dim();
254 let max_r = max_rank.unwrap_or(m.min(n));
255 let bs = block_size.unwrap_or(5);
256
257 if tolerance <= F::zero() {
258 return Err(LinalgError::InvalidInput(
259 "Tolerance must be positive".to_string(),
260 ));
261 }
262
263 let mut q_cols: Vec<Array1<F>> = Vec::new();
264 let mut current_rank = 0;
265
266 while current_rank < max_r {
267 let add_count = bs.min(max_r - current_rank);
268
269 let omega = gaussian_random_matrix::<F>(n, add_count)?;
271 let mut y_block = a.dot(&omega);
272
273 for q_col in &q_cols {
275 for j in 0..add_count {
276 let mut y_col = y_block.column(j).to_owned();
277 let dot: F = y_col
278 .iter()
279 .zip(q_col.iter())
280 .fold(F::zero(), |acc, (&yi, &qi)| acc + yi * qi);
281 for i in 0..m {
282 y_col[i] -= dot * q_col[i];
283 }
284 for i in 0..m {
285 y_block[[i, j]] = y_col[i];
286 }
287 }
288 }
289
290 let q_new = if y_block.nrows() >= y_block.ncols() {
292 let (q_tmp, _) = qr(&y_block.view(), None)?;
293 q_tmp
294 } else {
295 let (u_tmp, _, _) = svd(&y_block.view(), false, None)?;
296 u_tmp
297 };
298
299 let mut all_below_tol = true;
301 let cols_to_add = add_count.min(q_new.ncols());
302 for j in 0..cols_to_add {
303 let col = q_new.column(j);
304 let norm: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
305 if norm > tolerance {
306 all_below_tol = false;
307 q_cols.push(col.to_owned());
308 current_rank += 1;
309 }
310 }
311
312 if all_below_tol {
313 break;
314 }
315 }
316
317 if q_cols.is_empty() {
318 return Err(LinalgError::ComputationError(
319 "Adaptive range finder found no significant directions".to_string(),
320 ));
321 }
322
323 let k = q_cols.len();
325 let mut q = Array2::zeros((m, k));
326 for (j, col) in q_cols.iter().enumerate() {
327 for i in 0..m {
328 q[[i, j]] = col[i];
329 }
330 }
331
332 if q.nrows() >= q.ncols() {
334 let (q_final, _) = qr(&q.view(), None)?;
335 let k_final = k.min(q_final.ncols());
336 Ok(q_final.slice(s![.., ..k_final]).to_owned())
337 } else {
338 let (u_final, _, _) = svd(&q.view(), false, None)?;
339 let k_final = k.min(u_final.ncols());
340 Ok(u_final.slice(s![.., ..k_final]).to_owned())
341 }
342}
343
344pub fn randomized_svd<F>(
369 a: &ArrayView2<F>,
370 config: &RandomizedConfig,
371) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
372where
373 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
374{
375 let k = config.rank;
376 let (m, n) = a.dim();
377
378 if k == 0 {
379 return Err(LinalgError::InvalidInput(
380 "Target rank must be greater than 0".to_string(),
381 ));
382 }
383 if k > m.min(n) {
384 return Err(LinalgError::InvalidInput(format!(
385 "Target rank ({k}) exceeds min(m, n) = {}",
386 m.min(n)
387 )));
388 }
389
390 let q = randomized_range_finder(
392 a,
393 k,
394 Some(config.oversampling),
395 Some(config.power_iterations),
396 )?;
397
398 let b = q.t().dot(a);
400
401 let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
403
404 let u = q.dot(&u_b);
406
407 let k_actual = k.min(sigma.len()).min(u.ncols()).min(vt.nrows());
409 let u_k = u.slice(s![.., ..k_actual]).to_owned();
410 let s_k = sigma.slice(s![..k_actual]).to_owned();
411 let vt_k = vt.slice(s![..k_actual, ..]).to_owned();
412
413 Ok((u_k, s_k, vt_k))
414}
415
416pub fn single_pass_svd<F>(
448 a: &ArrayView2<F>,
449 rank: usize,
450 oversampling: Option<usize>,
451) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
452where
453 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
454{
455 let (m, n) = a.dim();
456 let p = oversampling.unwrap_or(10);
457 let l = (rank + p).min(m).min(n);
458
459 if rank == 0 {
460 return Err(LinalgError::InvalidInput(
461 "Target rank must be greater than 0".to_string(),
462 ));
463 }
464 if rank > m.min(n) {
465 return Err(LinalgError::InvalidInput(format!(
466 "Target rank ({rank}) exceeds min(m, n) = {}",
467 m.min(n)
468 )));
469 }
470
471 let omega = gaussian_random_matrix::<F>(n, l)?;
473 let psi = gaussian_random_matrix::<F>(m, l)?;
474
475 let y = a.dot(&omega);
477 let z = a.t().dot(&psi);
478
479 let q = if y.nrows() >= y.ncols() {
481 let (q_tmp, _) = qr(&y.view(), None)?;
482 let l_a = l.min(q_tmp.ncols());
483 q_tmp.slice(s![.., ..l_a]).to_owned()
484 } else {
485 let (u_tmp, _, _) = svd(&y.view(), false, None)?;
486 let l_a = l.min(u_tmp.ncols()).min(m);
487 u_tmp.slice(s![.., ..l_a]).to_owned()
488 };
489
490 let b = q.t().dot(a);
509
510 let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
512
513 let u = q.dot(&u_b);
515
516 let k = rank.min(sigma.len()).min(u.ncols()).min(vt.nrows());
518 let u_k = u.slice(s![.., ..k]).to_owned();
519 let s_k = sigma.slice(s![..k]).to_owned();
520 let vt_k = vt.slice(s![..k, ..]).to_owned();
521
522 Ok((u_k, s_k, vt_k))
523}
524
525pub fn randomized_low_rank<F>(
545 a: &ArrayView2<F>,
546 rank: usize,
547 oversampling: Option<usize>,
548 power_iterations: Option<usize>,
549) -> LinalgResult<(Array2<F>, Array2<F>)>
550where
551 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
552{
553 let (m, n) = a.dim();
554
555 if rank == 0 {
556 return Err(LinalgError::InvalidInput(
557 "Target rank must be greater than 0".to_string(),
558 ));
559 }
560 if rank > m.min(n) {
561 return Err(LinalgError::InvalidInput(format!(
562 "Target rank ({rank}) exceeds min(m, n) = {}",
563 m.min(n)
564 )));
565 }
566
567 let q = randomized_range_finder(a, rank, oversampling, power_iterations)?;
569
570 let b = q.t().dot(a);
572
573 let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
576
577 let k = rank.min(sigma.len()).min(u_b.ncols()).min(vt.nrows());
578
579 let u_bk = u_b.slice(s![.., ..k]).to_owned();
581 let mut l = q.dot(&u_bk);
582 for j in 0..k {
583 let sj = sigma[j];
584 for i in 0..m {
585 l[[i, j]] *= sj;
586 }
587 }
588
589 let r = vt.slice(s![..k, ..]).to_owned();
591
592 Ok((l, r))
593}
594
595pub fn approximation_error<F>(a: &ArrayView2<F>, q: &ArrayView2<F>) -> LinalgResult<F>
606where
607 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
608{
609 let (m, n) = a.dim();
610 if q.nrows() != m {
611 return Err(LinalgError::DimensionError(format!(
612 "Q has {} rows but A has {} rows",
613 q.nrows(),
614 m
615 )));
616 }
617
618 let qt_a = q.t().dot(a);
620 let q_qt_a = q.dot(&qt_a);
621
622 let mut frobenius_sq = F::zero();
623 for i in 0..m {
624 for j in 0..n {
625 let diff = a[[i, j]] - q_qt_a[[i, j]];
626 frobenius_sq += diff * diff;
627 }
628 }
629
630 Ok(frobenius_sq.sqrt())
631}
632
633pub fn randomized_pca<F>(
653 data: &ArrayView2<F>,
654 n_components: usize,
655 whiten: bool,
656 power_iterations: Option<usize>,
657) -> LinalgResult<RandomizedPcaResult<F>>
658where
659 F: Float
660 + NumAssign
661 + Sum
662 + Debug
663 + scirs2_core::ndarray::ScalarOperand
664 + Send
665 + Sync
666 + 'static,
667{
668 let (n_samples, n_features) = data.dim();
669
670 if n_components == 0 {
671 return Err(LinalgError::InvalidInput(
672 "Number of components must be greater than 0".to_string(),
673 ));
674 }
675 if n_components > n_features.min(n_samples) {
676 return Err(LinalgError::InvalidInput(format!(
677 "n_components ({n_components}) exceeds min(n_samples, n_features) = {}",
678 n_features.min(n_samples)
679 )));
680 }
681
682 let mut mean = Array1::zeros(n_features);
684 let n_f = F::from(n_samples)
685 .ok_or_else(|| LinalgError::ComputationError("Failed to convert n_samples".to_string()))?;
686
687 for j in 0..n_features {
688 let col_sum: F = data.column(j).sum();
689 mean[j] = col_sum / n_f;
690 }
691
692 let mut centered = data.to_owned();
693 for i in 0..n_samples {
694 for j in 0..n_features {
695 centered[[i, j]] -= mean[j];
696 }
697 }
698
699 let config = RandomizedConfig::new(n_components)
701 .with_oversampling(10)
702 .with_power_iterations(power_iterations.unwrap_or(2));
703
704 let (u, sigma, vt) = randomized_svd(¢ered.view(), &config)?;
705
706 let k = sigma.len();
707
708 let denom = F::from(n_samples.saturating_sub(1).max(1)).ok_or_else(|| {
710 LinalgError::ComputationError("Failed to convert denominator".to_string())
711 })?;
712
713 let explained_variance = sigma.mapv(|s| s * s / denom);
714
715 let total_var = {
717 let mut total = F::zero();
718 for j in 0..n_features {
719 let col = centered.column(j);
720 let col_var: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x) / denom;
721 total += col_var;
722 }
723 total
724 };
725
726 let explained_variance_ratio = if total_var > F::zero() {
727 explained_variance.mapv(|v| v / total_var)
728 } else {
729 Array1::zeros(k)
730 };
731
732 let components = if whiten {
734 let mut whitened = vt.slice(s![..k, ..]).to_owned();
736 for i in 0..k {
737 if sigma[i] > F::epsilon() {
738 let scale = F::one() / sigma[i];
739 for j in 0..n_features {
740 whitened[[i, j]] *= scale;
741 }
742 }
743 }
744 whitened
745 } else {
746 vt.slice(s![..k, ..]).to_owned()
747 };
748
749 Ok(RandomizedPcaResult {
750 components,
751 explained_variance,
752 explained_variance_ratio,
753 singular_values: sigma.slice(s![..k]).to_owned(),
754 mean,
755 })
756}
757
758pub fn pca_transform<F>(
771 data: &ArrayView2<F>,
772 pca_result: &RandomizedPcaResult<F>,
773) -> LinalgResult<Array2<F>>
774where
775 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
776{
777 let (n_samples, n_features) = data.dim();
778 if n_features != pca_result.mean.len() {
779 return Err(LinalgError::DimensionError(format!(
780 "Data has {} features but PCA was fitted with {} features",
781 n_features,
782 pca_result.mean.len()
783 )));
784 }
785
786 let mut centered = data.to_owned();
788 for i in 0..n_samples {
789 for j in 0..n_features {
790 centered[[i, j]] -= pca_result.mean[j];
791 }
792 }
793
794 let transformed = centered.dot(&pca_result.components.t());
797
798 Ok(transformed)
799}
800
801pub fn pca_inverse_transform<F>(
812 transformed: &ArrayView2<F>,
813 pca_result: &RandomizedPcaResult<F>,
814) -> LinalgResult<Array2<F>>
815where
816 F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
817{
818 let (n_samples, n_components) = transformed.dim();
819 let n_features = pca_result.mean.len();
820
821 if n_components != pca_result.components.nrows() {
822 return Err(LinalgError::DimensionError(format!(
823 "Transformed data has {} components but PCA has {} components",
824 n_components,
825 pca_result.components.nrows()
826 )));
827 }
828
829 let mut reconstructed = transformed.dot(&pca_result.components);
831
832 for i in 0..n_samples {
833 for j in 0..n_features {
834 reconstructed[[i, j]] += pca_result.mean[j];
835 }
836 }
837
838 Ok(reconstructed)
839}
840
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Builds an m x n matrix of rank `rank` as the product of two Gaussian
    // factors (so randomized methods should recover it near-exactly).
    fn make_low_rank_matrix(m: usize, n: usize, rank: usize) -> Array2<f64> {
        let mut rng = scirs2_core::random::rng();
        let normal =
            Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));
        let mut a_left = Array2::zeros((m, rank));
        let mut a_right = Array2::zeros((rank, n));
        for i in 0..m {
            for j in 0..rank {
                a_left[[i, j]] = normal.sample(&mut rng);
            }
        }
        for i in 0..rank {
            for j in 0..n {
                a_right[[i, j]] = normal.sample(&mut rng);
            }
        }
        a_left.dot(&a_right)
    }

    // Q returned by the range finder must have orthonormal columns.
    #[test]
    fn test_randomized_range_finder_basic() {
        let a = array![
            [3.0, 1.0, 0.5],
            [1.0, 3.0, 0.5],
            [0.5, 0.5, 2.0],
            [1.0, 1.0, 1.0]
        ];

        let q = randomized_range_finder(&a.view(), 2, Some(1), Some(1));
        assert!(q.is_ok());
        let q = q.expect("range finder failed");
        assert_eq!(q.nrows(), 4);
        assert!(q.ncols() >= 2);

        // Check Q^T Q ≈ I.
        let qtq = q.t().dot(&q);
        for i in 0..qtq.nrows() {
            for j in 0..qtq.ncols() {
                if i == j {
                    assert!(
                        (qtq[[i, j]] - 1.0).abs() < 1e-6,
                        "Q^TQ not identity on diagonal"
                    );
                } else {
                    assert!(qtq[[i, j]].abs() < 1e-6, "Q^TQ not identity off-diagonal");
                }
            }
        }
    }

    // rank == 0 and rank > min(m, n) must both be rejected.
    #[test]
    fn test_randomized_range_finder_error_cases() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(randomized_range_finder(&a.view(), 0, None, None).is_err());
        assert!(randomized_range_finder(&a.view(), 5, None, None).is_err());
    }

    // The adaptive finder should discover at least the true rank.
    #[test]
    fn test_adaptive_range_finder() {
        let a = make_low_rank_matrix(20, 15, 3);
        let q = adaptive_range_finder(&a.view(), 1e-6, Some(10), Some(2));
        assert!(q.is_ok());
        let q = q.expect("adaptive range finder failed");
        assert!(q.ncols() >= 3, "Should detect at least rank 3");
    }

    // Shapes and singular-value ordering of the truncated factors.
    #[test]
    fn test_randomized_svd_basic() {
        let a = array![
            [3.0, 1.0, 0.5],
            [1.0, 3.0, 0.5],
            [0.5, 0.5, 2.0],
            [1.0, 1.0, 1.0]
        ];

        let config = RandomizedConfig::new(2)
            .with_oversampling(1)
            .with_power_iterations(2);
        let result = randomized_svd(&a.view(), &config);
        assert!(result.is_ok());
        let (u, s, vt) = result.expect("randomized SVD failed");

        assert_eq!(u.nrows(), 4);
        assert_eq!(u.ncols(), 2);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.nrows(), 2);
        assert_eq!(vt.ncols(), 3);

        // Singular values must be positive and non-increasing.
        assert!(s[0] > 0.0);
        assert!(s[0] >= s[1]);
    }

    // A rank-3 matrix should be reconstructed from 3 components to within
    // a small relative Frobenius error.
    #[test]
    fn test_randomized_svd_low_rank() {
        let a = make_low_rank_matrix(30, 20, 3);
        let config = RandomizedConfig::new(3).with_power_iterations(3);
        let result = randomized_svd(&a.view(), &config);
        assert!(result.is_ok());

        let (u, s, vt) = result.expect("randomized SVD failed");

        // Rebuild A ≈ U·diag(S)·Vt element by element.
        let mut reconstructed = Array2::zeros((30, 20));
        for i in 0..30 {
            for j in 0..20 {
                let mut val = 0.0;
                for k in 0..3 {
                    val += u[[i, k]] * s[k] * vt[[k, j]];
                }
                reconstructed[[i, j]] = val;
            }
        }

        let mut error = 0.0;
        let mut total = 0.0;
        for i in 0..30 {
            for j in 0..20 {
                let diff = a[[i, j]] - reconstructed[[i, j]];
                error += diff * diff;
                total += a[[i, j]] * a[[i, j]];
            }
        }
        let rel_error = (error / total).sqrt();
        assert!(
            rel_error < 0.1,
            "Reconstruction error too large: {rel_error}"
        );
    }

    #[test]
    fn test_randomized_svd_error_cases() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let config0 = RandomizedConfig::new(0);
        assert!(randomized_svd(&a.view(), &config0).is_err());

        let config5 = RandomizedConfig::new(5);
        assert!(randomized_svd(&a.view(), &config5).is_err());
    }

    // Only checks shapes: single-pass output is random, not value-stable.
    #[test]
    fn test_single_pass_svd() {
        let a = array![
            [3.0, 1.0, 0.5],
            [1.0, 3.0, 0.5],
            [0.5, 0.5, 2.0],
            [1.0, 1.0, 1.0]
        ];

        let result = single_pass_svd(&a.view(), 2, Some(1));
        assert!(result.is_ok());
        let (u, s, vt) = result.expect("single pass SVD failed");

        assert_eq!(u.nrows(), 4);
        assert_eq!(u.ncols(), 2);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.nrows(), 2);
        assert_eq!(vt.ncols(), 3);
    }

    #[test]
    fn test_single_pass_svd_errors() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(single_pass_svd(&a.view(), 0, None).is_err());
        assert!(single_pass_svd(&a.view(), 5, None).is_err());
    }

    // L·R must approximate a genuinely rank-3 matrix well.
    #[test]
    fn test_randomized_low_rank() {
        let a = make_low_rank_matrix(20, 15, 3);
        let result = randomized_low_rank(&a.view(), 3, Some(5), Some(2));
        assert!(result.is_ok());
        let (l, r) = result.expect("low rank failed");

        assert_eq!(l.nrows(), 20);
        assert_eq!(l.ncols(), 3);
        assert_eq!(r.nrows(), 3);
        assert_eq!(r.ncols(), 15);

        // Relative Frobenius error of the factorization.
        let approx = l.dot(&r);
        let mut error = 0.0;
        let mut total = 0.0;
        for i in 0..20 {
            for j in 0..15 {
                let diff = a[[i, j]] - approx[[i, j]];
                error += diff * diff;
                total += a[[i, j]] * a[[i, j]];
            }
        }
        let rel_error = if total > 0.0 {
            (error / total).sqrt()
        } else {
            0.0
        };
        assert!(
            rel_error < 0.2,
            "Low-rank approximation error too large: {rel_error}"
        );
    }

    #[test]
    fn test_randomized_low_rank_errors() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(randomized_low_rank(&a.view(), 0, None, None).is_err());
        assert!(randomized_low_rank(&a.view(), 5, None, None).is_err());
    }

    // A full-rank basis (rank 2 of a rank-2 matrix) gives ~zero residual.
    #[test]
    fn test_approximation_error() {
        let a = array![[3.0, 1.0], [1.0, 3.0], [0.5, 0.5]];
        let q =
            randomized_range_finder(&a.view(), 2, Some(0), Some(1)).expect("range finder failed");
        let err = approximation_error(&a.view(), &q.view());
        assert!(err.is_ok());
        let err_val = err.expect("approx error failed");
        assert!(
            err_val < 1e-6,
            "Full-rank approximation error should be small"
        );
    }

    // Mismatched row counts between A and Q must be rejected.
    #[test]
    fn test_approximation_error_dimension_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let q = array![[1.0], [0.0], [0.0]]; assert!(approximation_error(&a.view(), &q.view()).is_err());
    }

    // Data with two strong correlated directions: the top two components
    // should explain most of the variance.
    #[test]
    fn test_randomized_pca_basic() {
        let mut data = Array2::zeros((50, 5));
        let mut rng = scirs2_core::random::rng();
        let normal =
            Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));

        for i in 0..50 {
            let c1 = normal.sample(&mut rng);
            let c2 = normal.sample(&mut rng);
            data[[i, 0]] = c1 * 3.0;
            data[[i, 1]] = c1 * 3.0 + normal.sample(&mut rng) * 0.1;
            data[[i, 2]] = c2 * 2.0;
            data[[i, 3]] = c2 * 2.0 + normal.sample(&mut rng) * 0.1;
            data[[i, 4]] = normal.sample(&mut rng) * 0.01;
        }

        let result = randomized_pca(&data.view(), 2, false, Some(3));
        assert!(result.is_ok());
        let pca = result.expect("PCA failed");

        assert_eq!(pca.components.nrows(), 2);
        assert_eq!(pca.components.ncols(), 5);
        assert_eq!(pca.explained_variance.len(), 2);
        assert_eq!(pca.explained_variance_ratio.len(), 2);
        assert_eq!(pca.singular_values.len(), 2);
        assert_eq!(pca.mean.len(), 5);

        // By construction almost all variance lives in two directions.
        let total_explained: f64 = pca.explained_variance_ratio.sum();
        assert!(
            total_explained > 0.8,
            "Top 2 components should explain >80% variance, got {total_explained}"
        );
    }

    #[test]
    fn test_randomized_pca_whiten() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0]
        ];

        let result = randomized_pca(&data.view(), 2, true, Some(1));
        assert!(result.is_ok());
        let pca = result.expect("whitened PCA failed");
        assert_eq!(pca.components.nrows(), 2);
    }

    #[test]
    fn test_randomized_pca_error_cases() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(randomized_pca(&data.view(), 0, false, None).is_err());
        assert!(randomized_pca(&data.view(), 5, false, None).is_err());
    }

    // Round trip: transform then inverse-transform should roughly recover
    // the original samples.
    #[test]
    fn test_pca_transform_and_inverse() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let pca = randomized_pca(&data.view(), 2, false, Some(2)).expect("PCA failed");

        let transformed = pca_transform(&data.view(), &pca).expect("transform failed");
        assert_eq!(transformed.nrows(), 4);
        assert_eq!(transformed.ncols(), 2);

        let reconstructed =
            pca_inverse_transform(&transformed.view(), &pca).expect("inverse transform failed");
        assert_eq!(reconstructed.nrows(), 4);
        assert_eq!(reconstructed.ncols(), 3);

        // Loose tolerance: this data is effectively rank 1 plus offset.
        for i in 0..4 {
            for j in 0..3 {
                assert!(
                    (data[[i, j]] - reconstructed[[i, j]]).abs() < 1.0,
                    "Reconstruction error too large at [{i}, {j}]"
                );
            }
        }
    }

    // Transforming data with the wrong feature count must be rejected.
    #[test]
    fn test_pca_transform_dimension_mismatch() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let pca = randomized_pca(&data.view(), 1, false, Some(1)).expect("PCA failed");

        let wrong_data = array![[1.0, 2.0, 3.0]]; assert!(pca_transform(&wrong_data.view(), &pca).is_err());
    }

    // Builder methods should set each field independently.
    #[test]
    fn test_config_builder() {
        let config = RandomizedConfig::new(5)
            .with_oversampling(20)
            .with_power_iterations(3)
            .with_seed(42);

        assert_eq!(config.rank, 5);
        assert_eq!(config.oversampling, 20);
        assert_eq!(config.power_iterations, 3);
        assert_eq!(config.seed, Some(42));
    }

    // All singular values of a (padded) identity are 1.
    #[test]
    fn test_randomized_svd_identity_like() {
        let a = array![
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0]
        ];

        let config = RandomizedConfig::new(3)
            .with_oversampling(0)
            .with_power_iterations(1);
        let result = randomized_svd(&a.view(), &config);
        assert!(result.is_ok());
        let (_u, s, _vt) = result.expect("SVD of identity-like failed");

        for i in 0..s.len() {
            assert!(
                (s[i] - 1.0).abs() < 0.1,
                "Singular value {} = {}, expected ~1.0",
                i,
                s[i]
            );
        }
    }
}