Skip to main content

scirs2_stats/multivariate/
canonical_correlation.rs

1//! Canonical Correlation Analysis (CCA)
2//!
3//! CCA finds linear combinations of two sets of variables that are maximally correlated.
4//! It's useful for understanding relationships between two multivariate datasets.
5
6use crate::error::{StatsError, StatsResult as Result};
7use crate::error_handling_v2::ErrorCode;
8use crate::{unified_error_handling::global_error_handler, validate_or_error};
9use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
10use statrs::statistics::Statistics;
11
/// Canonical Correlation Analysis (CCA) configuration.
///
/// CCA finds pairs of linear combinations — one of the X variables, one of the
/// Y variables — whose correlation is maximal. Build an instance with
/// [`CanonicalCorrelationAnalysis::new`] (or `Default`) and the `with_*`
/// builder methods, then call [`CanonicalCorrelationAnalysis::fit`].
#[derive(Debug, Clone)]
pub struct CanonicalCorrelationAnalysis {
    /// Number of canonical components to compute.
    ///
    /// `None` requests as many as possible. Any value is capped during `fit`
    /// at `min(n_features_x, n_features_y, n_samples - 1)`.
    pub n_components: Option<usize>,
    /// Whether to divide each centred column by its standard deviation
    /// before fitting.
    pub scale: bool,
    /// Ridge regularisation strength for the auto-covariance matrices.
    ///
    /// `reg_param * trace(C) / dim(C)` is added to the diagonal of each
    /// auto-covariance matrix `C`. Values `<= 0.0` disable regularisation.
    pub reg_param: f64,
    /// Maximum number of iterations for iterative algorithms.
    ///
    /// NOTE: the current `fit` is closed-form (eigendecomposition + SVD) and
    /// does not read this field.
    pub max_iter: usize,
    /// Convergence tolerance for iterative algorithms.
    ///
    /// NOTE: not read by the current closed-form `fit`.
    pub tol: f64,
}
29
30/// Result of Canonical Correlation Analysis
31#[derive(Debug, Clone)]
32pub struct CCAResult {
33    /// Canonical coefficients for first dataset (X)
34    pub x_weights: Array2<f64>,
35    /// Canonical coefficients for second dataset (Y)
36    pub y_weights: Array2<f64>,
37    /// Canonical correlations
38    pub correlations: Array1<f64>,
39    /// Canonical loadings for X (correlations between X variables and X canonical variates)
40    pub x_loadings: Array2<f64>,
41    /// Canonical loadings for Y (correlations between Y variables and Y canonical variates)
42    pub y_loadings: Array2<f64>,
43    /// Cross-loadings for X (correlations between X variables and Y canonical variates)
44    pub x_cross_loadings: Array2<f64>,
45    /// Cross-loadings for Y (correlations between Y variables and X canonical variates)
46    pub y_cross_loadings: Array2<f64>,
47    /// Means of X variables
48    pub x_mean: Array1<f64>,
49    /// Means of Y variables
50    pub y_mean: Array1<f64>,
51    /// Standard deviations of X variables (if scaled)
52    pub x_std: Option<Array1<f64>>,
53    /// Standard deviations of Y variables (if scaled)
54    pub y_std: Option<Array1<f64>>,
55    /// Number of components computed
56    pub n_components: usize,
57    /// Proportion of variance explained in X by each canonical component
58    pub x_explained_variance_ratio: Array1<f64>,
59    /// Proportion of variance explained in Y by each canonical component
60    pub y_explained_variance_ratio: Array1<f64>,
61}
62
63impl Default for CanonicalCorrelationAnalysis {
64    fn default() -> Self {
65        Self {
66            n_components: None,
67            scale: true,
68            reg_param: 1e-6,
69            max_iter: 500,
70            tol: 1e-8,
71        }
72    }
73}
74
75impl CanonicalCorrelationAnalysis {
76    /// Create a new CCA instance
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Set number of components to compute
82    pub fn with_n_components(mut self, ncomponents: usize) -> Self {
83        self.n_components = Some(ncomponents);
84        self
85    }
86
87    /// Set whether to scale the data
88    pub fn with_scale(mut self, scale: bool) -> Self {
89        self.scale = scale;
90        self
91    }
92
93    /// Set regularization parameter
94    pub fn with_reg_param(mut self, regparam: f64) -> Self {
95        self.reg_param = regparam;
96        self
97    }
98
99    /// Set maximum iterations
100    pub fn with_max_iter(mut self, maxiter: usize) -> Self {
101        self.max_iter = maxiter;
102        self
103    }
104
105    /// Set convergence tolerance
106    pub fn with_tolerance(mut self, tol: f64) -> Self {
107        self.tol = tol;
108        self
109    }
110
111    /// Fit the CCA model
112    pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<CCAResult> {
113        let handler = global_error_handler();
114        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "CCA fit");
115        validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "CCA fit");
116
117        let (n_samples_x, n_features_x) = x.dim();
118        let (n_samples_y, n_features_y) = y.dim();
119
120        if n_samples_x != n_samples_y {
121            return Err(handler
122                .create_validation_error(
123                    ErrorCode::E2001,
124                    "CCA fit",
125                    "samplesize_mismatch",
126                    format!("x: {}, y: {}", n_samples_x, n_samples_y),
127                    "X and Y must have the same number of samples",
128                )
129                .error);
130        }
131
132        let n_samples_ = n_samples_x;
133        if n_samples_ < 2 {
134            return Err(handler
135                .create_validation_error(
136                    ErrorCode::E2003,
137                    "CCA fit",
138                    "n_samples_",
139                    n_samples_,
140                    "CCA requires at least 2 samples",
141                )
142                .error);
143        }
144
145        if n_features_x == 0 || n_features_y == 0 {
146            return Err(handler
147                .create_validation_error(
148                    ErrorCode::E2004,
149                    "CCA fit",
150                    "n_features",
151                    format!("x: {}, y: {}", n_features_x, n_features_y),
152                    "Both X and Y must have at least one feature",
153                )
154                .error);
155        }
156
157        // Determine number of components
158        let max_components = n_features_x.min(n_features_y).min(n_samples_ - 1);
159        let n_components = self
160            .n_components
161            .unwrap_or(max_components)
162            .min(max_components);
163
164        if n_components == 0 {
165            return Err(handler
166                .create_validation_error(
167                    ErrorCode::E1001,
168                    "CCA fit",
169                    "n_components",
170                    n_components,
171                    "Number of components must be positive",
172                )
173                .error);
174        }
175
176        // Center and optionally scale the data
177        let (x_centered, x_mean, x_std) = self.center_and_scale(x)?;
178        let (y_centered, y_mean, y_std) = self.center_and_scale(y)?;
179
180        // Compute cross-covariance and auto-covariance matrices
181        let (cxx, cyy, cxy) = self.compute_covariance_matrices(&x_centered, &y_centered)?;
182
183        // Solve the generalized eigenvalue problem
184        let (x_weights, y_weights, correlations) =
185            self.solve_cca_eigenvalue_problem(&cxx, &cyy, &cxy, n_components)?;
186
187        // Compute loadings and cross-loadings
188        let x_canonical = x_centered.dot(&x_weights);
189        let y_canonical = y_centered.dot(&y_weights);
190
191        let x_loadings = self.compute_loadings(&x_centered, &x_canonical)?;
192        let y_loadings = self.compute_loadings(&y_centered, &y_canonical)?;
193        let x_cross_loadings = self.compute_loadings(&x_centered, &y_canonical)?;
194        let y_cross_loadings = self.compute_loadings(&y_centered, &x_canonical)?;
195
196        // Compute explained variance ratios
197        let x_explained_variance_ratio =
198            self.compute_explained_variance(&x_centered, &x_canonical)?;
199        let y_explained_variance_ratio =
200            self.compute_explained_variance(&y_centered, &y_canonical)?;
201
202        Ok(CCAResult {
203            x_weights,
204            y_weights,
205            correlations,
206            x_loadings,
207            y_loadings,
208            x_cross_loadings,
209            y_cross_loadings,
210            x_mean,
211            y_mean,
212            x_std,
213            y_std,
214            n_components,
215            x_explained_variance_ratio,
216            y_explained_variance_ratio,
217        })
218    }
219
220    /// Center and optionally scale data
221    fn center_and_scale(
222        &self,
223        data: ArrayView2<f64>,
224    ) -> Result<(Array2<f64>, Array1<f64>, Option<Array1<f64>>)> {
225        let mean = data.mean_axis(Axis(0)).expect("Operation failed");
226        let mut centered = data.to_owned();
227
228        // Center data
229        for mut row in centered.rows_mut() {
230            row -= &mean;
231        }
232
233        if self.scale {
234            // Compute standard deviations
235            let mut std_dev = Array1::zeros(data.ncols());
236            for j in 0..data.ncols() {
237                let col = centered.column(j);
238                let variance = col.mapv(|x| x * x).mean();
239                std_dev[j] = variance.sqrt().max(1e-10); // Avoid division by zero
240            }
241
242            // Scale data
243            for mut row in centered.rows_mut() {
244                for j in 0..row.len() {
245                    row[j] /= std_dev[j];
246                }
247            }
248
249            Ok((centered, mean, Some(std_dev)))
250        } else {
251            Ok((centered, mean, None))
252        }
253    }
254
255    /// Compute covariance matrices
256    fn compute_covariance_matrices(
257        &self,
258        x: &Array2<f64>,
259        y: &Array2<f64>,
260    ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>)> {
261        let n_samples_ = x.nrows() as f64;
262
263        // Auto-covariance matrices
264        let cxx = x.t().dot(x) / (n_samples_ - 1.0);
265        let cyy = y.t().dot(y) / (n_samples_ - 1.0);
266
267        // Cross-covariance matrix
268        let cxy = x.t().dot(y) / (n_samples_ - 1.0);
269
270        Ok((cxx, cyy, cxy))
271    }
272
273    /// Solve the CCA eigenvalue problem
274    fn solve_cca_eigenvalue_problem(
275        &self,
276        cxx: &Array2<f64>,
277        cyy: &Array2<f64>,
278        cxy: &Array2<f64>,
279        n_components: usize,
280    ) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>)> {
281        // Regularized versions of covariance matrices
282        let cxx_reg = self.regularize_covariance(cxx)?;
283        let cyy_reg = self.regularize_covariance(cyy)?;
284
285        // Compute inverse square roots
286        let cxx_inv_sqrt = self.compute_inverse_sqrt(&cxx_reg)?;
287        let cyy_inv_sqrt = self.compute_inverse_sqrt(&cyy_reg)?;
288
289        // Form the matrix for SVD: Cxx^{-1/2} * Cxy * Cyy^{-1/2}
290        let k = cxx_inv_sqrt.dot(cxy).dot(&cyy_inv_sqrt);
291
292        // SVD of K using scirs2_linalg
293        let (u, s, vt) = scirs2_linalg::svd(&k.view(), true, None)
294            .map_err(|e| StatsError::ComputationError(format!("SVD failed in CCA: {}", e)))?;
295
296        // Extract the desired number of _components
297        let n_comp = n_components.min(s.len());
298        let correlations = s.slice(scirs2_core::ndarray::s![..n_comp]).to_owned();
299        let u_comp = u.slice(scirs2_core::ndarray::s![.., ..n_comp]).to_owned();
300        let v_comp = vt
301            .slice(scirs2_core::ndarray::s![..n_comp, ..])
302            .t()
303            .to_owned();
304
305        // Transform back to original space
306        let x_weights = cxx_inv_sqrt.dot(&u_comp);
307        let y_weights = cyy_inv_sqrt.dot(&v_comp);
308
309        Ok((x_weights, y_weights, correlations))
310    }
311
312    /// Regularize covariance matrix for numerical stability
313    fn regularize_covariance(&self, cov: &Array2<f64>) -> Result<Array2<f64>> {
314        if self.reg_param <= 0.0 {
315            return Ok(cov.clone());
316        }
317
318        let n = cov.nrows();
319        let trace = (0..n).map(|i| cov[[i, i]]).sum::<f64>();
320        let reg_term: Array2<f64> = Array2::eye(n) * (self.reg_param * trace / n as f64);
321
322        Ok(cov + &reg_term)
323    }
324
325    /// Compute inverse square root of a symmetric positive definite matrix
326    fn compute_inverse_sqrt(&self, matrix: &Array2<f64>) -> Result<Array2<f64>> {
327        // Eigenvalue decomposition using scirs2_linalg
328        let (eigenvalues, eigenvectors) =
329            scirs2_linalg::eigh_f64_lapack(&matrix.view()).map_err(|e| {
330                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
331            })?;
332
333        // Check for positive definiteness
334        let min_eigenvalue = eigenvalues.iter().cloned().fold(f64::INFINITY, f64::min);
335        if min_eigenvalue <= 1e-10 {
336            return Err(StatsError::ComputationError(format!(
337                "Matrix is not positive definite (min eigenvalue: {})",
338                min_eigenvalue
339            )));
340        }
341
342        // Compute inverse square root
343        let inv_sqrt_eigenvalues = eigenvalues.mapv(|x: f64| x.sqrt().recip());
344        let mut inv_sqrt = Array2::zeros(matrix.dim());
345
346        for i in 0..eigenvalues.len() {
347            let eigenvec = eigenvectors.column(i);
348            let lambda_inv_sqrt = inv_sqrt_eigenvalues[i];
349
350            for j in 0..matrix.nrows() {
351                for k in 0..matrix.ncols() {
352                    inv_sqrt[[j, k]] += lambda_inv_sqrt * eigenvec[j] * eigenvec[k];
353                }
354            }
355        }
356
357        Ok(inv_sqrt)
358    }
359
360    /// Compute loadings (correlations between original variables and canonical variates)
361    fn compute_loadings(
362        &self,
363        original: &Array2<f64>,
364        canonical: &Array2<f64>,
365    ) -> Result<Array2<f64>> {
366        let n_samples_ = original.nrows() as f64;
367        let n_original = original.ncols();
368        let n_canonical = canonical.ncols();
369
370        let mut loadings = Array2::zeros((n_original, n_canonical));
371
372        for i in 0..n_original {
373            let orig_var = original.column(i);
374            let orig_var_std = (orig_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
375
376            for j in 0..n_canonical {
377                let canon_var = canonical.column(j);
378                let canon_var_std = (canon_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
379
380                if orig_var_std > 1e-10 && canon_var_std > 1e-10 {
381                    let covariance = orig_var.dot(&canon_var) / (n_samples_ - 1.0);
382                    let correlation = covariance / (orig_var_std * canon_var_std);
383                    loadings[[i, j]] = correlation;
384                }
385            }
386        }
387
388        Ok(loadings)
389    }
390
391    /// Compute explained variance ratio
392    fn compute_explained_variance(
393        &self,
394        original: &Array2<f64>,
395        canonical: &Array2<f64>,
396    ) -> Result<Array1<f64>> {
397        let n_samples_ = original.nrows() as f64;
398        let n_canonical = canonical.ncols();
399
400        // Total variance in original variables
401        let total_variance = (0..original.ncols())
402            .map(|i| {
403                let col = original.column(i);
404                col.mapv(|x| x * x).sum() / (n_samples_ - 1.0)
405            })
406            .sum::<f64>();
407
408        if total_variance <= 1e-10 {
409            return Ok(Array1::zeros(n_canonical));
410        }
411
412        // Variance explained by each canonical component
413        let mut explained_variance = Array1::zeros(n_canonical);
414        for j in 0..n_canonical {
415            let canon_var = canonical.column(j);
416            let canon_variance = canon_var.mapv(|x| x * x).sum() / (n_samples_ - 1.0);
417            explained_variance[j] = canon_variance / total_variance;
418        }
419
420        Ok(explained_variance)
421    }
422
423    /// Transform new data using fitted CCA model
424    pub fn transform(
425        &self,
426        x: ArrayView2<f64>,
427        y: ArrayView2<f64>,
428        result: &CCAResult,
429    ) -> Result<(Array2<f64>, Array2<f64>)> {
430        let handler = global_error_handler();
431        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "CCA transform");
432        validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "CCA transform");
433
434        if x.ncols() != result.x_mean.len() {
435            return Err(handler
436                .create_validation_error(
437                    ErrorCode::E2001,
438                    "CCA transform",
439                    "x_features",
440                    format!("input: {}, expected: {}", x.ncols(), result.x_mean.len()),
441                    "X must have the same number of features as training data",
442                )
443                .error);
444        }
445
446        if y.ncols() != result.y_mean.len() {
447            return Err(handler
448                .create_validation_error(
449                    ErrorCode::E2001,
450                    "CCA transform",
451                    "y_features",
452                    format!("input: {}, expected: {}", y.ncols(), result.y_mean.len()),
453                    "Y must have the same number of features as training data",
454                )
455                .error);
456        }
457
458        // Center and scale X
459        let mut x_processed = x.to_owned();
460        for mut row in x_processed.rows_mut() {
461            row -= &result.x_mean;
462        }
463
464        if let Some(ref x_std) = result.x_std {
465            for mut row in x_processed.rows_mut() {
466                for j in 0..row.len() {
467                    row[j] /= x_std[j];
468                }
469            }
470        }
471
472        // Center and scale Y
473        let mut y_processed = y.to_owned();
474        for mut row in y_processed.rows_mut() {
475            row -= &result.y_mean;
476        }
477
478        if let Some(ref y_std) = result.y_std {
479            for mut row in y_processed.rows_mut() {
480                for j in 0..row.len() {
481                    row[j] /= y_std[j];
482                }
483            }
484        }
485
486        // Transform to canonical space
487        let x_canonical = x_processed.dot(&result.x_weights);
488        let y_canonical = y_processed.dot(&result.y_weights);
489
490        Ok((x_canonical, y_canonical))
491    }
492
493    /// Compute canonical correlations for new data
494    pub fn score(
495        &self,
496        x: ArrayView2<f64>,
497        y: ArrayView2<f64>,
498        result: &CCAResult,
499    ) -> Result<Array1<f64>> {
500        let (x_canonical, y_canonical) = self.transform(x, y, result)?;
501        let n_samples_ = x_canonical.nrows() as f64;
502        let n_components = result.n_components;
503
504        let mut correlations = Array1::zeros(n_components);
505        for i in 0..n_components {
506            let x_comp = x_canonical.column(i);
507            let y_comp = y_canonical.column(i);
508
509            let x_std = (x_comp.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
510            let y_std = (y_comp.mapv(|x| x * x).sum() / (n_samples_ - 1.0)).sqrt();
511
512            if x_std > 1e-10 && y_std > 1e-10 {
513                let covariance = x_comp.dot(&y_comp) / (n_samples_ - 1.0);
514                correlations[i] = covariance / (x_std * y_std);
515            }
516        }
517
518        Ok(correlations)
519    }
520}
521
/// Partial Least Squares in canonical mode (PLS-C).
///
/// Like CCA, PLS extracts pairs of latent components from two data sets.
/// Unlike CCA, each pair is chosen to maximise the *covariance* of the X and Y
/// scores rather than their correlation.
///
/// Components are extracted one at a time with the NIPALS algorithm. Both X
/// and Y are deflated with their own scores: symmetric "canonical" deflation,
/// as opposed to PLS regression, which deflates Y with the X scores.
#[derive(Debug, Clone)]
pub struct PLSCanonical {
    /// Maximum number of components to extract.
    ///
    /// `fit` may return fewer when the residual variance is exhausted or
    /// NIPALS fails to converge.
    pub n_components: usize,
    /// Whether to divide each centred column by its population standard
    /// deviation before fitting.
    pub scale: bool,
    /// Maximum number of NIPALS iterations per component.
    ///
    /// A component that has not converged within this budget ends extraction.
    pub max_iter: usize,
    /// Convergence tolerance on the L1 change of the X weight vector between
    /// successive NIPALS iterations.
    pub tol: f64,
}
536
537/// Result of PLS Canonical analysis
538#[derive(Debug, Clone)]
539pub struct PLSResult {
540    /// X weights
541    pub x_weights: Array2<f64>,
542    /// Y weights
543    pub y_weights: Array2<f64>,
544    /// X loadings
545    pub x_loadings: Array2<f64>,
546    /// Y loadings
547    pub y_loadings: Array2<f64>,
548    /// X scores
549    pub x_scores: Array2<f64>,
550    /// Y scores
551    pub y_scores: Array2<f64>,
552    /// X rotation matrix
553    pub x_rotations: Array2<f64>,
554    /// Y rotation matrix
555    pub y_rotations: Array2<f64>,
556    /// Means
557    pub x_mean: Array1<f64>,
558    pub y_mean: Array1<f64>,
559    /// Standard deviations (if scaled)
560    pub x_std: Option<Array1<f64>>,
561    pub y_std: Option<Array1<f64>>,
562}
563
564impl Default for PLSCanonical {
565    fn default() -> Self {
566        Self {
567            n_components: 2,
568            scale: true,
569            max_iter: 500,
570            tol: 1e-6,
571        }
572    }
573}
574
575impl PLSCanonical {
576    /// Create new PLS instance
577    pub fn new(_ncomponents: usize) -> Self {
578        Self {
579            n_components: _ncomponents,
580            ..Default::default()
581        }
582    }
583
584    /// Fit PLS model using NIPALS algorithm
585    pub fn fit(&self, x: ArrayView2<f64>, y: ArrayView2<f64>) -> Result<PLSResult> {
586        let handler = global_error_handler();
587        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "PLS fit");
588        validate_or_error!(finite: y.as_slice().expect("Operation failed"), "y", "PLS fit");
589
590        let (n_samples_, n_x_features) = x.dim();
591        let (n_samples_y, n_y_features) = y.dim();
592
593        if n_samples_ != n_samples_y {
594            return Err(handler
595                .create_validation_error(
596                    ErrorCode::E2001,
597                    "PLS fit",
598                    "samplesize_mismatch",
599                    format!("x: {}, y: {}", n_samples_, n_samples_y),
600                    "X and Y must have the same number of samples",
601                )
602                .error);
603        }
604
605        // Center and scale data
606        let cca = CanonicalCorrelationAnalysis {
607            scale: self.scale,
608            ..Default::default()
609        };
610        let (mut x_current, x_mean, x_std) = cca.center_and_scale(x)?;
611        let (mut y_current, y_mean, y_std) = cca.center_and_scale(y)?;
612
613        // Initialize result matrices
614        let mut x_weights = Array2::zeros((n_x_features, self.n_components));
615        let mut y_weights = Array2::zeros((n_y_features, self.n_components));
616        let mut x_loadings = Array2::zeros((n_x_features, self.n_components));
617        let mut y_loadings = Array2::zeros((n_y_features, self.n_components));
618        let mut x_scores = Array2::zeros((n_samples_, self.n_components));
619        let mut y_scores = Array2::zeros((n_samples_, self.n_components));
620
621        // NIPALS algorithm
622        let mut actual_components = 0;
623        for comp in 0..self.n_components {
624            // Check if there's sufficient variance left for another component
625            let x_var = x_current.iter().map(|&x| x * x).sum::<f64>();
626            let y_var = y_current.iter().map(|&y| y * y).sum::<f64>();
627
628            if x_var < 1e-12 || y_var < 1e-12 {
629                // Not enough variance left, stop here
630                break;
631            }
632
633            // Initialize weights with first column of Y
634            let mut u = y_current.column(0).to_owned();
635            let mut w_old = Array1::zeros(n_x_features);
636
637            let mut converged_inner = false;
638            for _iter in 0..self.max_iter {
639                // X weights
640                let w = x_current.t().dot(&u);
641                let w_norm = (w.dot(&w)).sqrt();
642                if w_norm < 1e-10 {
643                    // No more meaningful components can be extracted
644                    converged_inner = false;
645                    break;
646                }
647                let w = w / w_norm;
648
649                // X scores
650                let t = x_current.dot(&w);
651
652                // Y weights
653                let c = y_current.t().dot(&t);
654                let c_norm = (c.dot(&c)).sqrt();
655                if c_norm < 1e-10 {
656                    return Err(StatsError::ComputationError(
657                        "Y weights became zero".to_string(),
658                    ));
659                }
660                let c = c / c_norm;
661
662                // Y scores
663                u = y_current.dot(&c);
664
665                // Check convergence
666                let diff = (&w - &w_old).mapv(|x| x.abs()).sum();
667                if diff < self.tol {
668                    converged_inner = true;
669                    break;
670                }
671                w_old = w.clone();
672            }
673
674            // If inner loop didn't converge, skip this component
675            if !converged_inner {
676                break;
677            }
678
679            // Compute loadings
680            let w = x_current.t().dot(&u);
681            let w_norm = (w.dot(&w)).sqrt();
682            if w_norm < 1e-10 {
683                break; // Can't extract this component
684            }
685            let w = w.clone() / w_norm;
686            let t = x_current.dot(&w);
687            let c = y_current.t().dot(&t);
688            let c_norm = (c.dot(&c)).sqrt();
689            if c_norm < 1e-10 {
690                break; // Can't extract this component
691            }
692            let c = c.clone() / c_norm;
693            let u = y_current.dot(&c);
694
695            let t_dot_t = t.dot(&t);
696            let u_dot_u = u.dot(&u);
697            if t_dot_t < 1e-10 || u_dot_u < 1e-10 {
698                break; // Can't extract this component
699            }
700
701            let p = x_current.t().dot(&t) / t_dot_t;
702            let q = y_current.t().dot(&u) / u_dot_u;
703
704            // Store results
705            x_weights.column_mut(comp).assign(&w);
706            y_weights.column_mut(comp).assign(&c);
707            x_loadings.column_mut(comp).assign(&p);
708            y_loadings.column_mut(comp).assign(&q);
709            x_scores.column_mut(comp).assign(&t);
710            y_scores.column_mut(comp).assign(&u);
711
712            actual_components += 1;
713
714            // Deflate matrices
715            let _tt = Array1::from_vec(vec![t.dot(&t)]);
716            let outer_product = &t
717                .view()
718                .insert_axis(Axis(1))
719                .dot(&p.view().insert_axis(Axis(0)));
720            x_current -= outer_product;
721
722            let _uu = Array1::from_vec(vec![u.dot(&u)]);
723            let outer_product_y = &u
724                .view()
725                .insert_axis(Axis(1))
726                .dot(&q.view().insert_axis(Axis(0)));
727            y_current -= outer_product_y;
728        }
729
730        // Slice matrices to actual components extracted
731        let x_weights = x_weights.slice(s![.., ..actual_components]).to_owned();
732        let y_weights = y_weights.slice(s![.., ..actual_components]).to_owned();
733        let x_loadings = x_loadings.slice(s![.., ..actual_components]).to_owned();
734        let y_loadings = y_loadings.slice(s![.., ..actual_components]).to_owned();
735        let x_scores = x_scores.slice(s![.., ..actual_components]).to_owned();
736        let y_scores = y_scores.slice(s![.., ..actual_components]).to_owned();
737
738        // Compute rotation matrices only if we have components
739        let (x_rotations, y_rotations) = if actual_components > 0 {
740            let x_rot = x_weights.dot(
741                &scirs2_linalg::inv(&(x_loadings.t().dot(&x_weights)).view(), None).map_err(
742                    |e| {
743                        StatsError::ComputationError(format!(
744                            "Failed to compute X rotations: {}",
745                            e
746                        ))
747                    },
748                )?,
749            );
750
751            let y_rot = y_weights.dot(
752                &scirs2_linalg::inv(&(y_loadings.t().dot(&y_weights)).view(), None).map_err(
753                    |e| {
754                        StatsError::ComputationError(format!(
755                            "Failed to compute Y rotations: {}",
756                            e
757                        ))
758                    },
759                )?,
760            );
761            (x_rot, y_rot)
762        } else {
763            (
764                Array2::zeros((n_x_features, 0)),
765                Array2::zeros((n_y_features, 0)),
766            )
767        };
768
769        Ok(PLSResult {
770            x_weights,
771            y_weights,
772            x_loadings,
773            y_loadings,
774            x_scores,
775            y_scores,
776            x_rotations,
777            y_rotations,
778            x_mean,
779            y_mean,
780            x_std,
781            y_std,
782        })
783    }
784
785    /// Transform new data
786    pub fn transform(&self, x: ArrayView2<f64>, result: &PLSResult) -> Result<Array2<f64>> {
787        let handler = global_error_handler();
788        validate_or_error!(finite: x.as_slice().expect("Operation failed"), "x", "PLS transform");
789
790        if x.ncols() != result.x_mean.len() {
791            return Err(handler
792                .create_validation_error(
793                    ErrorCode::E2001,
794                    "PLS transform",
795                    "n_features",
796                    format!("input: {}, expected: {}", x.ncols(), result.x_mean.len()),
797                    "Number of features must match training data",
798                )
799                .error);
800        }
801
802        // Center and scale
803        let mut x_processed = x.to_owned();
804        for mut row in x_processed.rows_mut() {
805            row -= &result.x_mean;
806        }
807
808        if let Some(ref x_std) = result.x_std {
809            for mut row in x_processed.rows_mut() {
810                for j in 0..row.len() {
811                    row[j] /= x_std[j];
812                }
813            }
814        }
815
816        Ok(x_processed.dot(&result.x_rotations))
817    }
818}
819
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cca_basic() {
        // X has collinear columns and Y is an exact linear function of X.
        let x = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
            [5.0, 6.0, 7.0],
        ];
        let y = array![
            [2.0, 4.0],
            [4.0, 6.0],
            [6.0, 8.0],
            [8.0, 10.0],
            [10.0, 12.0],
        ];

        let model = CanonicalCorrelationAnalysis::new().with_n_components(2);
        let fitted = model.fit(x.view(), y.view()).expect("Operation failed");

        assert_eq!(fitted.n_components, 2);
        assert_eq!(fitted.x_weights.ncols(), 2);
        assert_eq!(fitted.y_weights.ncols(), 2);
        assert_eq!(fitted.correlations.len(), 2);

        // Projecting the training data yields one row per sample and one
        // column per component.
        let (x_scores, y_scores) = model
            .transform(x.view(), y.view(), &fitted)
            .expect("Operation failed");
        assert_eq!(x_scores.dim(), (5, 2));
        assert_eq!(y_scores.dim(), (5, 2));
    }

    #[test]
    fn test_pls_basic() {
        // Enough independent variation to support two latent components.
        let x = array![[1.0, 3.0], [2.0, 1.0], [3.0, 4.0], [4.0, 2.0], [5.0, 5.0]];
        let y = array![[2.0, 6.0], [4.0, 2.0], [6.0, 8.0], [8.0, 4.0], [10.0, 10.0]];

        let model = PLSCanonical::new(2);
        let fitted = model.fit(x.view(), y.view()).expect("Operation failed");

        assert_eq!(fitted.x_weights.ncols(), 2);
        assert_eq!(fitted.y_weights.ncols(), 2);
        assert_eq!(fitted.x_scores.nrows(), 5);
        assert_eq!(fitted.y_scores.nrows(), 5);

        let projected = model.transform(x.view(), &fitted).expect("Operation failed");
        assert_eq!(projected.dim(), (5, 2));
    }
}