1use crate::error::{StatsError, StatsResult as Result};
7use crate::error_handling_v2::ErrorCode;
8use crate::{unified_error_handling::global_error_handler, validate_or_error};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
10use statrs::statistics::Statistics;
11
/// Canonical Correlation Analysis (CCA) estimator.
///
/// Finds pairs of linear projections of two data sets `X` and `Y` that share the
/// same samples (but may have different features) such that the projected
/// variates are maximally correlated. Configure it with the `with_*` builder
/// methods and call `fit`.
#[derive(Debug, Clone, PartialEq)]
pub struct CanonicalCorrelationAnalysis {
    /// Number of canonical component pairs to extract. `None` requests the
    /// maximum, `min(n_features_x, n_features_y, n_samples - 1)`; larger values
    /// are clamped to that maximum by `fit`.
    pub n_components: Option<usize>,
    /// Whether each centered feature is divided by its standard deviation
    /// before the covariance matrices are formed.
    pub scale: bool,
    /// Ridge strength added to the diagonal of each within-set covariance
    /// matrix, multiplied by that matrix's mean diagonal value. Values `<= 0.0`
    /// disable regularization.
    pub reg_param: f64,
    /// Iteration cap. The current SVD-based solver does not read this field.
    pub max_iter: usize,
    /// Convergence tolerance. The current SVD-based solver does not read this field.
    pub tol: f64,
}
29
30#[derive(Debug, Clone)]
32pub struct CCAResult {
33 pub x_weights: Array2<f64>,
35 pub y_weights: Array2<f64>,
37 pub correlations: Array1<f64>,
39 pub x_loadings: Array2<f64>,
41 pub y_loadings: Array2<f64>,
43 pub x_cross_loadings: Array2<f64>,
45 pub y_cross_loadings: Array2<f64>,
47 pub x_mean: Array1<f64>,
49 pub y_mean: Array1<f64>,
51 pub x_std: Option<Array1<f64>>,
53 pub y_std: Option<Array1<f64>>,
55 pub n_components: usize,
57 pub x_explained_variance_ratio: Array1<f64>,
59 pub y_explained_variance_ratio: Array1<f64>,
61}
62
63impl Default for CanonicalCorrelationAnalysis {
64 fn default() -> Self {
65 Self {
66 n_components: None,
67 scale: true,
68 reg_param: 1e-6,
69 max_iter: 500,
70 tol: 1e-8,
71 }
72 }
73}
74
75impl CanonicalCorrelationAnalysis {
76 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn with_n_components(mut self, ncomponents: usize) -> Self {
83 self.n_components = Some(ncomponents);
84 self
85 }
86
87 pub fn with_scale(mut self, scale: bool) -> Self {
89 self.scale = scale;
90 self
91 }
92
93 pub fn with_reg_param(mut self, regparam: f64) -> Self {
95 self.reg_param = regparam;
96 self
97 }
98
99 pub fn with_max_iter(mut self, maxiter: usize) -> Self {
101 self.max_iter = maxiter;
102 self
103 }
104
105 pub fn with_tolerance(mut self, tol: f64) -> Self {
107 self.tol = tol;
108 self
109 }
110
111 pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<CCAResult> {
113 let handler = global_error_handler();
114 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "CCA fit");
115 validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "CCA fit");
116
117 let (n_samples_x, n_features_x) = x.dim();
118 let (n_samples_y, n_features_y) = y.dim();
119
120 if n_samples_x != n_samples_y {
121 return Err(handler
122 .create_validation_error(
123 ErrorCode::E2001,
124 "CCA fit",
125 "samplesize_mismatch",
126 format!("x: {}, y: {}", n_samples_x, n_samples_y),
127 "X and Y must have the same number of samples",
128 )
129 .error);
130 }
131
132 let n_samples_ = n_samples_x;
133 if n_samples_ < 2 {
134 return Err(handler
135 .create_validation_error(
136 ErrorCode::E2003,
137 "CCA fit",
138 "n_samples_",
139 n_samples_,
140 "CCA requires at least 2 samples",
141 )
142 .error);
143 }
144
145 if n_features_x == 0 || n_features_y == 0 {
146 return Err(handler
147 .create_validation_error(
148 ErrorCode::E2004,
149 "CCA fit",
150 "n_features",
151 format!("x: {}, y: {}", n_features_x, n_features_y),
152 "Both X and Y must have at least one feature",
153 )
154 .error);
155 }
156
157 let max_components = n_features_x.min(n_features_y).min(n_samples_ - 1);
159 let n_components = self
160 .n_components
161 .unwrap_or(max_components)
162 .min(max_components);
163
164 if n_components == 0 {
165 return Err(handler
166 .create_validation_error(
167 ErrorCode::E1001,
168 "CCA fit",
169 "n_components",
170 n_components,
171 "Number of components must be positive",
172 )
173 .error);
174 }
175
176 let (x_centered, x_mean, x_std) = self.center_and_scale(x)?;
178 let (y_centered, y_mean, y_std) = self.center_and_scale(y)?;
179
180 let (cxx, cyy, cxy) = self.compute_covariance_matrices(&x_centered, &y_centered)?;
182
183 let (x_weights, y_weights, correlations) =
185 self.solve_cca_eigenvalue_problem(&cxx, &cyy, &cxy, n_components)?;
186
187 let x_canonical = x_centered.dot(&x_weights);
189 let y_canonical = y_centered.dot(&y_weights);
190
191 let x_loadings = self.compute_loadings(&x_centered, &x_canonical)?;
192 let y_loadings = self.compute_loadings(&y_centered, &y_canonical)?;
193 let x_cross_loadings = self.compute_loadings(&x_centered, &y_canonical)?;
194 let y_cross_loadings = self.compute_loadings(&y_centered, &x_canonical)?;
195
196 let x_explained_variance_ratio =
198 self.compute_explained_variance(&x_centered, &x_canonical)?;
199 let y_explained_variance_ratio =
200 self.compute_explained_variance(&y_centered, &y_canonical)?;
201
202 Ok(CCAResult {
203 x_weights,
204 y_weights,
205 correlations,
206 x_loadings,
207 y_loadings,
208 x_cross_loadings,
209 y_cross_loadings,
210 x_mean,
211 y_mean,
212 x_std,
213 y_std,
214 n_components,
215 x_explained_variance_ratio,
216 y_explained_variance_ratio,
217 })
218 }
219
220 fn center_and_scale(
222 &self,
223 data: ArrayView2<f64>,
224 ) -> Result<(Array2<f64>, Array1<f64>, Option<Array1<f64>>)> {
225 let mean = data.mean_axis(Axis(0)).expect("Operation failed");
226 let mut centered = data.to_owned();
227
228 for mut row in centered.rows_mut() {
230 row -= &mean;
231 }
232
233 if self.scale {
234 let mut std_dev = Array1::zeros(data.ncols());
236 for j in 0..data.ncols() {
237 let col = centered.column(j);
238 let variance = col.mapv(|x| x * x).mean();
239 std_dev[j] = variance.sqrt().max(1e-10); }
241
242 for mut row in centered.rows_mut() {
244 for j in 0..row.len() {
245 row[j] /= std_dev[j];
246 }
247 }
248
249 Ok((centered, mean, Some(std_dev)))
250 } else {
251 Ok((centered, mean, None))
252 }
253 }
254
255 fn compute_covariance_matrices(
257 &self,
258 x: &Array2<f64>,
259 y: &Array2<f64>,
260 ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)> {
261 let n_samples_ = x.nrows() as f64;
262
263 let cxx = x.t().dot(x) / (n_samples_ - 1.0);
265 let cyy = y.t().dot(y) / (n_samples_ - 1.0);
266
267 let cxy = x.t().dot(y) / (n_samples_ - 1.0);
269
270 Ok((cxx, cyy, cxy))
271 }
272
273 fn solve_cca_eigenvalue_problem(
275 &self,
276 cxx: &Array2<f64>,
277 cyy: &Array2<f64>,
278 cxy: &Array2<f64>,
279 n_components: usize,
280 ) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>)> {
281 let cxx_reg = self.regularize_covariance(cxx)?;
283 let cyy_reg = self.regularize_covariance(cyy)?;
284
285 let cxx_inv_sqrt = self.compute_inverse_sqrt(&cxx_reg)?;
287 let cyy_inv_sqrt = self.compute_inverse_sqrt(&cyy_reg)?;
288
289 let k = cxx_inv_sqrt.dot(cxy).dot(&cyy_inv_sqrt);
291
292 let (u, s, vt) = scirs2_linalg::svd(&k.view(), true, None)
294 .map_err(|e| StatsError::ComputationError(format!("SVD failed in CCA: {}", e)))?;
295
296 let n_comp = n_components.min(s.len());
298 let correlations = s.slice(scirs2_core::ndarray::s![..n_comp]).to_owned();
299 let u_comp = u.slice(scirs2_core::ndarray::s![.., ..n_comp]).to_owned();
300 let v_comp = vt
301 .slice(scirs2_core::ndarray::s![..n_comp, ..])
302 .t()
303 .to_owned();
304
305 let x_weights = cxx_inv_sqrt.dot(&u_comp);
307 let y_weights = cyy_inv_sqrt.dot(&v_comp);
308
309 Ok((x_weights, y_weights, correlations))
310 }
311
312 fn regularize_covariance(&self, cov: &Array2<f64>) -> Result<Array2<f64>> {
314 if self.reg_param <= 0.0 {
315 return Ok(cov.clone());
316 }
317
318 let n = cov.nrows();
319 let trace = (0..n).map(|i| cov[[i, i]]).sum::<f64>();
320 let reg_term: Array2<f64> = Array2::eye(n) * (self.reg_param * trace / n as f64);
321
322 Ok(cov + ®_term)
323 }
324
325 fn compute_inverse_sqrt(&self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
327 let (eigenvalues, eigenvectors) =
329 scirs2_linalg::eigh_f64_lapack(&matrix.view()).map_err(|e| {
330 StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
331 })?;
332
333 let min_eigenvalue = eigenvalues.iter().cloned().fold(f64::INFINITY, f64::min);
335 if min_eigenvalue <= 1e-10 {
336 return Err(StatsError::ComputationError(format!(
337 "Matrix is not positive definite (min eigenvalue: {})",
338 min_eigenvalue
339 )));
340 }
341
342 let inv_sqrt_eigenvalues = eigenvalues.mapv(|x: f64| x.sqrt().recip());
344 let mut inv_sqrt = Array2::zeros(matrix.dim());
345
346 for i in 0..eigenvalues.len() {
347 let eigenvec = eigenvectors.column(i);
348 let lambda_inv_sqrt = inv_sqrt_eigenvalues[i];
349
350 for j in 0..matrix.nrows() {
351 for k in 0..matrix.ncols() {
352 inv_sqrt[[j, k]] += lambda_inv_sqrt * eigenvec[j] * eigenvec[k];
353 }
354 }
355 }
356
357 Ok(inv_sqrt)
358 }
359
360 fn compute_loadings(
362 &self,
363 original: &Array2<f64>,
364 canonical: &Array2<f64>,
365 ) -> Result<Array2<f64>> {
366 let n_samples_ = original.nrows() as f64;
367 let n_original = original.ncols();
368 let n_canonical = canonical.ncols();
369
370 let mut loadings = Array2::zeros((n_original, n_canonical));
371
372 for i in 0..n_original {
373 let orig_var = original.column(i);
374 let orig_var_std = (orig_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
375
376 for j in 0..n_canonical {
377 let canon_var = canonical.column(j);
378 let canon_var_std = (canon_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
379
380 if orig_var_std > 1e-10 && canon_var_std > 1e-10 {
381 let covariance = orig_var.dot(&canon_var) / (n_samples_ - 1.0);
382 let correlation = covariance / (orig_var_std * canon_var_std);
383 loadings[[i, j]] = correlation;
384 }
385 }
386 }
387
388 Ok(loadings)
389 }
390
391 fn compute_explained_variance(
393 &self,
394 original: &Array2<f64>,
395 canonical: &Array2<f64>,
396 ) -> Result<Array1<f64>> {
397 let n_samples_ = original.nrows() as f64;
398 let n_canonical = canonical.ncols();
399
400 let total_variance = (0..original.ncols())
402 .map(|i| {
403 let col = original.column(i);
404 col.mapv(|x| x * x).sum() / (n_samples_ - 1.0)
405 })
406 .sum::<f64>();
407
408 if total_variance <= 1e-10 {
409 return Ok(Array1::zeros(n_canonical));
410 }
411
412 let mut explained_variance = Array1::zeros(n_canonical);
414 for j in 0..n_canonical {
415 let canon_var = canonical.column(j);
416 let canon_variance = canon_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0);
417 explained_variance[j] = canon_variance / total_variance;
418 }
419
420 Ok(explained_variance)
421 }
422
423 pub fn transform(
425 &self,
426 x: ArrayView2<f64>,
427 y: ArrayView2<f64>,
428 result: &CCAResult,
429 ) -> Result<(Array2<f64>, Array2<f64>)> {
430 let handler = global_error_handler();
431 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "CCA transform");
432 validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "CCA transform");
433
434 if x.ncols() != result.x_mean.len() {
435 return Err(handler
436 .create_validation_error(
437 ErrorCode::E2001,
438 "CCA transform",
439 "x_features",
440 format!("input: {}, expected: {}", x.ncols(), result.x_mean.len()),
441 "X must have the same number of features as training data",
442 )
443 .error);
444 }
445
446 if y.ncols() != result.y_mean.len() {
447 return Err(handler
448 .create_validation_error(
449 ErrorCode::E2001,
450 "CCA transform",
451 "y_features",
452 format!("input: {}, expected: {}", y.ncols(), result.y_mean.len()),
453 "Y must have the same number of features as training data",
454 )
455 .error);
456 }
457
458 let mut x_processed = x.to_owned();
460 for mut row in x_processed.rows_mut() {
461 row -= &result.x_mean;
462 }
463
464 if let Some(ref x_std) = result.x_std {
465 for mut row in x_processed.rows_mut() {
466 for j in 0..row.len() {
467 row[j] /= x_std[j];
468 }
469 }
470 }
471
472 let mut y_processed = y.to_owned();
474 for mut row in y_processed.rows_mut() {
475 row -= &result.y_mean;
476 }
477
478 if let Some(ref y_std) = result.y_std {
479 for mut row in y_processed.rows_mut() {
480 for j in 0..row.len() {
481 row[j] /= y_std[j];
482 }
483 }
484 }
485
486 let x_canonical = x_processed.dot(&result.x_weights);
488 let y_canonical = y_processed.dot(&result.y_weights);
489
490 Ok((x_canonical, y_canonical))
491 }
492
493 pub fn score(
495 &self,
496 x: ArrayView2<f64>,
497 y: ArrayView2<f64>,
498 result: &CCAResult,
499 ) -> Result<Array1<f64>> {
500 let (x_canonical, y_canonical) = self.transform(x, y, result)?;
501 let n_samples_ = x_canonical.nrows() as f64;
502 let n_components = result.n_components;
503
504 let mut correlations = Array1::zeros(n_components);
505 for i in 0..n_components {
506 let x_comp = x_canonical.column(i);
507 let y_comp = y_canonical.column(i);
508
509 let x_std = (x_comp.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
510 let y_std = (y_comp.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
511
512 if x_std > 1e-10 && y_std > 1e-10 {
513 let covariance = x_comp.dot(&y_comp) / (n_samples_ - 1.0);
514 correlations[i] = covariance / (x_std * y_std);
515 }
516 }
517
518 Ok(correlations)
519 }
520}
521
/// Partial Least Squares in canonical mode, fitted with the NIPALS algorithm.
///
/// After each component, both the X and the Y residuals are deflated by their
/// own scores (symmetric deflation).
#[derive(Debug, Clone, PartialEq)]
pub struct PLSCanonical {
    /// Maximum number of components to extract; `fit` may return fewer.
    pub n_components: usize,
    /// Whether each centered feature is divided by its standard deviation before fitting.
    pub scale: bool,
    /// Maximum number of NIPALS iterations per component.
    pub max_iter: usize,
    /// NIPALS convergence threshold on the L1 change of the X weight vector
    /// between successive iterations.
    pub tol: f64,
}
536
537#[derive(Debug, Clone)]
539pub struct PLSResult {
540 pub x_weights: Array2<f64>,
542 pub y_weights: Array2<f64>,
544 pub x_loadings: Array2<f64>,
546 pub y_loadings: Array2<f64>,
548 pub x_scores: Array2<f64>,
550 pub y_scores: Array2<f64>,
552 pub x_rotations: Array2<f64>,
554 pub y_rotations: Array2<f64>,
556 pub x_mean: Array1<f64>,
558 pub y_mean: Array1<f64>,
559 pub x_std: Option<Array1<f64>>,
561 pub y_std: Option<Array1<f64>>,
562}
563
564impl Default for PLSCanonical {
565 fn default() -> Self {
566 Self {
567 n_components: 2,
568 scale: true,
569 max_iter: 500,
570 tol: 1e-6,
571 }
572 }
573}
574
575impl PLSCanonical {
576 pub fn new(_ncomponents: usize) -> Self {
578 Self {
579 n_components: _ncomponents,
580 ..Default::default()
581 }
582 }
583
584 pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<PLSResult> {
586 let handler = global_error_handler();
587 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "PLS fit");
588 validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "PLS fit");
589
590 let (n_samples_, n_x_features) = x.dim();
591 let (n_samples_y, n_y_features) = y.dim();
592
593 if n_samples_ != n_samples_y {
594 return Err(handler
595 .create_validation_error(
596 ErrorCode::E2001,
597 "PLS fit",
598 "samplesize_mismatch",
599 format!("x: {}, y: {}", n_samples_, n_samples_y),
600 "X and Y must have the same number of samples",
601 )
602 .error);
603 }
604
605 let cca = CanonicalCorrelationAnalysis {
607 scale: self.scale,
608 ..Default::default()
609 };
610 let (mut x_current, x_mean, x_std) = cca.center_and_scale(x)?;
611 let (mut y_current, y_mean, y_std) = cca.center_and_scale(y)?;
612
613 let mut x_weights = Array2::zeros((n_x_features, self.n_components));
615 let mut y_weights = Array2::zeros((n_y_features, self.n_components));
616 let mut x_loadings = Array2::zeros((n_x_features, self.n_components));
617 let mut y_loadings = Array2::zeros((n_y_features, self.n_components));
618 let mut x_scores = Array2::zeros((n_samples_, self.n_components));
619 let mut y_scores = Array2::zeros((n_samples_, self.n_components));
620
621 let mut actual_components = 0;
623 for comp in 0..self.n_components {
624 let x_var = x_current.iter().map(|&x| x * x).sum::<f64>();
626 let y_var = y_current.iter().map(|&y| y * y).sum::<f64>();
627
628 if x_var < 1e-12 || y_var < 1e-12 {
629 break;
631 }
632
633 let mut u = y_current.column(0).to_owned();
635 let mut w_old = Array1::zeros(n_x_features);
636
637 let mut converged_inner = false;
638 for _iter in 0..self.max_iter {
639 let w = x_current.t().dot(&u);
641 let w_norm = (w.dot(&w)).sqrt();
642 if w_norm < 1e-10 {
643 converged_inner = false;
645 break;
646 }
647 let w = w / w_norm;
648
649 let t = x_current.dot(&w);
651
652 let c = y_current.t().dot(&t);
654 let c_norm = (c.dot(&c)).sqrt();
655 if c_norm < 1e-10 {
656 return Err(StatsError::ComputationError(
657 "Y weights became zero".to_string(),
658 ));
659 }
660 let c = c / c_norm;
661
662 u = y_current.dot(&c);
664
665 let diff = (&w - &w_old).mapv(|x| x.abs()).sum();
667 if diff < self.tol {
668 converged_inner = true;
669 break;
670 }
671 w_old = w.clone();
672 }
673
674 if !converged_inner {
676 break;
677 }
678
679 let w = x_current.t().dot(&u);
681 let w_norm = (w.dot(&w)).sqrt();
682 if w_norm < 1e-10 {
683 break; }
685 let w = w.clone() / w_norm;
686 let t = x_current.dot(&w);
687 let c = y_current.t().dot(&t);
688 let c_norm = (c.dot(&c)).sqrt();
689 if c_norm < 1e-10 {
690 break; }
692 let c = c.clone() / c_norm;
693 let u = y_current.dot(&c);
694
695 let t_dot_t = t.dot(&t);
696 let u_dot_u = u.dot(&u);
697 if t_dot_t < 1e-10 || u_dot_u < 1e-10 {
698 break; }
700
701 let p = x_current.t().dot(&t) / t_dot_t;
702 let q = y_current.t().dot(&u) / u_dot_u;
703
704 x_weights.column_mut(comp).assign(&w);
706 y_weights.column_mut(comp).assign(&c);
707 x_loadings.column_mut(comp).assign(&p);
708 y_loadings.column_mut(comp).assign(&q);
709 x_scores.column_mut(comp).assign(&t);
710 y_scores.column_mut(comp).assign(&u);
711
712 actual_components += 1;
713
714 let _tt = Array1::from_vec(vec![t.dot(&t)]);
716 let outer_product = &t
717 .view()
718 .insert_axis(Axis(1))
719 .dot(&p.view().insert_axis(Axis(0)));
720 x_current -= outer_product;
721
722 let _uu = Array1::from_vec(vec![u.dot(&u)]);
723 let outer_product_y = &u
724 .view()
725 .insert_axis(Axis(1))
726 .dot(&q.view().insert_axis(Axis(0)));
727 y_current -= outer_product_y;
728 }
729
730 let x_weights = x_weights.slice(s![.., ..actual_components]).to_owned();
732 let y_weights = y_weights.slice(s![.., ..actual_components]).to_owned();
733 let x_loadings = x_loadings.slice(s![.., ..actual_components]).to_owned();
734 let y_loadings = y_loadings.slice(s![.., ..actual_components]).to_owned();
735 let x_scores = x_scores.slice(s![.., ..actual_components]).to_owned();
736 let y_scores = y_scores.slice(s![.., ..actual_components]).to_owned();
737
738 let (x_rotations, y_rotations) = if actual_components > 0 {
740 let x_rot = x_weights.dot(
741 &scirs2_linalg::inv(&(x_loadings.t().dot(&x_weights)).view(), None).map_err(
742 |e| {
743 StatsError::ComputationError(format!(
744 "Failed to compute X rotations: {}",
745 e
746 ))
747 },
748 )?,
749 );
750
751 let y_rot = y_weights.dot(
752 &scirs2_linalg::inv(&(y_loadings.t().dot(&y_weights)).view(), None).map_err(
753 |e| {
754 StatsError::ComputationError(format!(
755 "Failed to compute Y rotations: {}",
756 e
757 ))
758 },
759 )?,
760 );
761 (x_rot, y_rot)
762 } else {
763 (
764 Array2::zeros((n_x_features, 0)),
765 Array2::zeros((n_y_features, 0)),
766 )
767 };
768
769 Ok(PLSResult {
770 x_weights,
771 y_weights,
772 x_loadings,
773 y_loadings,
774 x_scores,
775 y_scores,
776 x_rotations,
777 y_rotations,
778 x_mean,
779 y_mean,
780 x_std,
781 y_std,
782 })
783 }
784
785 pub fn transform(&self, x: ArrayView2<f64>, result: &PLSResult) -> Result<Array2<f64>> {
787 let handler = global_error_handler();
788 validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "PLS transform");
789
790 if x.ncols() != result.x_mean.len() {
791 return Err(handler
792 .create_validation_error(
793 ErrorCode::E2001,
794 "PLS transform",
795 "n_features",
796 format!("input: {}, expected: {}", x.ncols(), result.x_mean.len()),
797 "Number of features must match training data",
798 )
799 .error);
800 }
801
802 let mut x_processed = x.to_owned();
804 for mut row in x_processed.rows_mut() {
805 row -= &result.x_mean;
806 }
807
808 if let Some(ref x_std) = result.x_std {
809 for mut row in x_processed.rows_mut() {
810 for j in 0..row.len() {
811 row[j] /= x_std[j];
812 }
813 }
814 }
815
816 Ok(x_processed.dot(&result.x_rotations))
817 }
818}
819
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// CCA on a small problem with perfectly collinear X features: checks the
    /// component count and the shapes of fitted weights and transformed variates.
    #[test]
    fn test_cca_basic() {
        let x = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
        ];
        let y = array![
            [2.0, 4.0],
            [4.0, 6.0],
            [6.0, 8.0],
            [8.0, 10.0],
            [10.0, 12.0],
        ];

        let estimator = CanonicalCorrelationAnalysis::new().with_n_components(2);
        let fitted = estimator
            .fit(x.view(), y.view())
            .expect("Operation failed");

        assert_eq!(fitted.n_components, 2);
        assert_eq!(fitted.correlations.len(), 2);
        assert_eq!(fitted.x_weights.ncols(), 2);
        assert_eq!(fitted.y_weights.ncols(), 2);

        let (x_variates, y_variates) = estimator
            .transform(x.view(), y.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(x_variates.dim(), (5, 2));
        assert_eq!(y_variates.dim(), (5, 2));
    }

    /// PLS canonical with Y exactly twice X: checks the shapes of weights,
    /// scores, and the transformed output.
    #[test]
    fn test_pls_basic() {
        let x = array![[1.0, 3.0], [2.0, 1.0], [3.0, 4.0], [4.0, 2.0], [5.0, 5.0]];
        let y = &x * 2.0;

        let model = PLSCanonical::new(2);
        let fitted = model.fit(x.view(), y.view()).expect("Operation failed");

        assert_eq!(fitted.x_weights.ncols(), 2);
        assert_eq!(fitted.y_weights.ncols(), 2);
        assert_eq!(fitted.x_scores.nrows(), 5);
        assert_eq!(fitted.y_scores.nrows(), 5);

        let projected = model
            .transform(x.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(projected.dim(), (5, 2));
    }
}