sklears_cross_decomposition/
multitask.rs

1//! Multi-task learning methods for cross-decomposition
2//!
3//! This module implements various multi-task learning approaches for cross-decomposition,
4//! including multi-task CCA, shared component analysis, transfer learning, domain adaptation,
5//! and few-shot learning methods.
6
7use scirs2_core::ndarray::{s, Array1, Array2, Axis};
8use scirs2_core::rand_prelude::SliceRandom;
9use scirs2_core::random::{thread_rng, Random, Rng};
10use sklears_core::error::SklearsError;
11use sklears_core::traits::Estimator;
12use std::collections::HashMap;
13
14/// Multi-task Canonical Correlation Analysis
15///
16/// Multi-task CCA learns canonical correlations across multiple related tasks,
17/// sharing information between tasks to improve performance on individual tasks.
18#[derive(Debug, Clone)]
19pub struct MultiTaskCCA {
20    n_components: usize,
21    reg_param: f64,
22    max_iter: usize,
23    tol: f64,
24    sharing_strength: f64,
25    canonical_weights_x: Option<Array2<f64>>,
26    canonical_weights_y: Option<Array2<f64>>,
27    shared_components: Option<Array2<f64>>,
28    task_specific_components: Option<HashMap<usize, Array2<f64>>>,
29    correlations: Option<Array1<f64>>,
30}
31
32impl MultiTaskCCA {
33    /// Creates a new MultiTaskCCA instance
34    pub fn new(n_components: usize, reg_param: f64, sharing_strength: f64) -> Self {
35        Self {
36            n_components,
37            reg_param,
38            max_iter: 500,
39            tol: 1e-6,
40            sharing_strength,
41            canonical_weights_x: None,
42            canonical_weights_y: None,
43            shared_components: None,
44            task_specific_components: None,
45            correlations: None,
46        }
47    }
48
49    /// Sets the maximum number of iterations
50    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
51        self.max_iter = max_iter;
52        self
53    }
54
55    /// Sets the convergence tolerance
56    pub fn with_tolerance(mut self, tol: f64) -> Self {
57        self.tol = tol;
58        self
59    }
60
61    /// Fits multi-task CCA on multiple datasets
62    pub fn fit_multi_task(
63        &self,
64        x_tasks: &[Array2<f64>],
65        y_tasks: &[Array2<f64>],
66    ) -> Result<Self, SklearsError> {
67        if x_tasks.len() != y_tasks.len() {
68            return Err(SklearsError::InvalidInput(
69                "Number of X and Y tasks must match".to_string(),
70            ));
71        }
72
73        if x_tasks.is_empty() {
74            return Err(SklearsError::InvalidInput(
75                "At least one task must be provided".to_string(),
76            ));
77        }
78
79        let n_tasks = x_tasks.len();
80        let n_features_x = x_tasks[0].shape()[1];
81        let n_features_y = y_tasks[0].shape()[1];
82
83        // Initialize shared and task-specific components
84        let mut shared_wx = Array2::zeros((n_features_x, self.n_components));
85        let mut shared_wy = Array2::zeros((n_features_y, self.n_components));
86        let mut task_specific_wx = HashMap::new();
87        let mut task_specific_wy = HashMap::new();
88
89        // Initialize task-specific components
90        for task_id in 0..n_tasks {
91            task_specific_wx.insert(task_id, Array2::zeros((n_features_x, self.n_components)));
92            task_specific_wy.insert(task_id, Array2::zeros((n_features_y, self.n_components)));
93        }
94
95        // Alternating optimization
96        for iter in 0..self.max_iter {
97            let mut converged = true;
98            let old_shared_wx = shared_wx.clone();
99
100            // Update shared components
101            for comp in 0..self.n_components {
102                let mut cov_xx_shared = Array2::zeros((n_features_x, n_features_x));
103                let mut cov_xy_shared = Array2::zeros((n_features_x, n_features_y));
104                let mut cov_yy_shared = Array2::zeros((n_features_y, n_features_y));
105
106                // Aggregate covariances across tasks
107                for (task_id, (x_task, y_task)) in x_tasks.iter().zip(y_tasks.iter()).enumerate() {
108                    let x_centered = self.center_data(x_task)?;
109                    let y_centered = self.center_data(y_task)?;
110
111                    let task_wx = &task_specific_wx[&task_id];
112                    let task_wy = &task_specific_wy[&task_id];
113
114                    // Compute task-specific residuals
115                    let x_proj = x_centered.dot(task_wx);
116                    let x_recon = x_proj.dot(&task_wx.t());
117                    let x_residual = &x_centered - &x_recon;
118
119                    let y_proj = y_centered.dot(task_wy);
120                    let y_recon = y_proj.dot(&task_wy.t());
121                    let y_residual = &y_centered - &y_recon;
122
123                    cov_xx_shared =
124                        cov_xx_shared + x_residual.t().dot(&x_residual) / x_task.shape()[0] as f64;
125                    cov_xy_shared =
126                        cov_xy_shared + x_residual.t().dot(&y_residual) / x_task.shape()[0] as f64;
127                    cov_yy_shared =
128                        cov_yy_shared + y_residual.t().dot(&y_residual) / y_task.shape()[0] as f64;
129                }
130
131                // Add regularization
132                cov_xx_shared
133                    .diag_mut()
134                    .mapv_inplace(|x| x + self.reg_param);
135                cov_yy_shared
136                    .diag_mut()
137                    .mapv_inplace(|x| x + self.reg_param);
138
139                // Solve generalized eigenvalue problem for shared components
140                let (eigvals, eigvecs_x, eigvecs_y) = self.solve_generalized_eigenvalue(
141                    &cov_xy_shared,
142                    &cov_xx_shared,
143                    &cov_yy_shared,
144                )?;
145
146                if let Some(max_idx) = eigvals
147                    .iter()
148                    .enumerate()
149                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
150                    .map(|(idx, _)| idx)
151                {
152                    shared_wx
153                        .column_mut(comp)
154                        .assign(&eigvecs_x.column(max_idx));
155                    shared_wy
156                        .column_mut(comp)
157                        .assign(&eigvecs_y.column(max_idx));
158                }
159            }
160
161            // Update task-specific components
162            for (task_id, (x_task, y_task)) in x_tasks.iter().zip(y_tasks.iter()).enumerate() {
163                let x_centered = self.center_data(x_task)?;
164                let y_centered = self.center_data(y_task)?;
165
166                // Remove shared component contribution
167                let x_shared_proj = x_centered.dot(&shared_wx);
168                let x_shared_recon = x_shared_proj.dot(&shared_wx.t());
169                let x_residual = &x_centered - &x_shared_recon;
170
171                let y_shared_proj = y_centered.dot(&shared_wy);
172                let y_shared_recon = y_shared_proj.dot(&shared_wy.t());
173                let y_residual = &y_centered - &y_shared_recon;
174
175                // Compute task-specific CCA
176                let cov_xx = x_residual.t().dot(&x_residual) / x_task.shape()[0] as f64;
177                let cov_xy = x_residual.t().dot(&y_residual) / x_task.shape()[0] as f64;
178                let cov_yy = y_residual.t().dot(&y_residual) / y_task.shape()[0] as f64;
179
180                let mut cov_xx_reg = cov_xx.clone();
181                let mut cov_yy_reg = cov_yy.clone();
182                cov_xx_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
183                cov_yy_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
184
185                let (_, eigvecs_x, eigvecs_y) =
186                    self.solve_generalized_eigenvalue(&cov_xy, &cov_xx_reg, &cov_yy_reg)?;
187
188                let n_comps = self.n_components.min(eigvecs_x.shape()[1]);
189                if let Some(task_wx) = task_specific_wx.get_mut(&task_id) {
190                    task_wx
191                        .slice_mut(s![.., ..n_comps])
192                        .assign(&eigvecs_x.slice(s![.., ..n_comps]));
193                }
194                if let Some(task_wy) = task_specific_wy.get_mut(&task_id) {
195                    task_wy
196                        .slice_mut(s![.., ..n_comps])
197                        .assign(&eigvecs_y.slice(s![.., ..n_comps]));
198                }
199            }
200
201            // Check convergence
202            let diff = (&shared_wx - &old_shared_wx).mapv(|x| x.abs()).sum();
203            if diff < self.tol {
204                converged = true;
205                break;
206            }
207
208            if iter == self.max_iter - 1 && !converged {
209                return Err(SklearsError::ConvergenceError {
210                    iterations: self.max_iter,
211                });
212            }
213        }
214
215        // Compute correlations for shared components
216        let mut correlations = Array1::zeros(self.n_components);
217        for comp in 0..self.n_components {
218            let mut total_corr = 0.0;
219            for (x_task, y_task) in x_tasks.iter().zip(y_tasks.iter()) {
220                let x_centered = self.center_data(x_task)?;
221                let y_centered = self.center_data(y_task)?;
222
223                let x_proj = x_centered.dot(&shared_wx.column(comp));
224                let y_proj = y_centered.dot(&shared_wy.column(comp));
225
226                let corr = self.compute_correlation(&x_proj, &y_proj)?;
227                total_corr += corr;
228            }
229            correlations[comp] = total_corr / n_tasks as f64;
230        }
231
232        Ok(Self {
233            canonical_weights_x: Some(shared_wx),
234            canonical_weights_y: Some(shared_wy),
235            shared_components: Some(Array2::zeros((self.n_components, self.n_components))),
236            task_specific_components: Some(task_specific_wx),
237            correlations: Some(correlations),
238            ..self.clone()
239        })
240    }
241
242    fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
243        let mean = data
244            .mean_axis(Axis(0))
245            .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
246        Ok(data - &mean)
247    }
248
249    fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
250        let n = x.len() as f64;
251        if n < 2.0 {
252            return Ok(0.0);
253        }
254
255        let mean_x = x.sum() / n;
256        let mean_y = y.sum() / n;
257
258        let mut cov = 0.0;
259        let mut var_x = 0.0;
260        let mut var_y = 0.0;
261
262        for i in 0..x.len() {
263            let dx = x[i] - mean_x;
264            let dy = y[i] - mean_y;
265            cov += dx * dy;
266            var_x += dx * dx;
267            var_y += dy * dy;
268        }
269
270        let denom = (var_x * var_y).sqrt();
271        if denom.abs() < 1e-12 {
272            Ok(0.0)
273        } else {
274            Ok(cov / denom)
275        }
276    }
277
278    fn solve_generalized_eigenvalue(
279        &self,
280        cov_xy: &Array2<f64>,
281        cov_xx: &Array2<f64>,
282        cov_yy: &Array2<f64>,
283    ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
284        // Simplified eigenvalue decomposition
285        // In a real implementation, this would use proper LAPACK routines
286        let n_features = cov_xx.shape()[0];
287        let n_comps = self.n_components.min(n_features);
288
289        let mut rng = thread_rng();
290        let mut eigvecs_x = Array2::zeros((n_features, n_comps));
291        let mut eigvecs_y = Array2::zeros((cov_yy.shape()[0], n_comps));
292        let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
293
294        // Initialize with random orthogonal vectors
295        for i in 0..n_comps {
296            for j in 0..n_features {
297                eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
298            }
299            for j in 0..cov_yy.shape()[0] {
300                eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
301            }
302        }
303
304        // Normalize columns
305        for i in 0..n_comps {
306            let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
307            let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
308            if norm_x > 1e-12 {
309                eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
310            }
311            if norm_y > 1e-12 {
312                eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
313            }
314        }
315
316        Ok((eigvals, eigvecs_x, eigvecs_y))
317    }
318
319    /// Gets the shared canonical weights for X
320    pub fn shared_weights_x(&self) -> Option<&Array2<f64>> {
321        self.canonical_weights_x.as_ref()
322    }
323
324    /// Gets the shared canonical weights for Y
325    pub fn shared_weights_y(&self) -> Option<&Array2<f64>> {
326        self.canonical_weights_y.as_ref()
327    }
328
329    /// Gets the task-specific weights for a given task
330    pub fn task_weights(&self, task_id: usize) -> Option<&Array2<f64>> {
331        self.task_specific_components.as_ref()?.get(&task_id)
332    }
333
334    /// Gets the canonical correlations
335    pub fn correlations(&self) -> Option<&Array1<f64>> {
336        self.correlations.as_ref()
337    }
338}
339
340/// Shared Component Analysis
341///
342/// Identifies components that are shared across multiple datasets/tasks
343/// and components that are specific to individual tasks.
344#[derive(Debug, Clone)]
345pub struct SharedComponentAnalysis {
346    n_shared_components: usize,
347    n_specific_components: usize,
348    reg_param: f64,
349    max_iter: usize,
350    tol: f64,
351    shared_components: Option<Array2<f64>>,
352    specific_components: Option<HashMap<usize, Array2<f64>>>,
353    explained_variance_shared: Option<Array1<f64>>,
354    explained_variance_specific: Option<HashMap<usize, Array1<f64>>>,
355}
356
357impl SharedComponentAnalysis {
358    /// Creates a new SharedComponentAnalysis instance
359    pub fn new(n_shared_components: usize, n_specific_components: usize, reg_param: f64) -> Self {
360        Self {
361            n_shared_components,
362            n_specific_components,
363            reg_param,
364            max_iter: 100,
365            tol: 1e-3,
366            shared_components: None,
367            specific_components: None,
368            explained_variance_shared: None,
369            explained_variance_specific: None,
370        }
371    }
372
373    /// Fits shared component analysis on multiple datasets
374    pub fn fit_datasets(&self, datasets: &[Array2<f64>]) -> Result<Self, SklearsError> {
375        if datasets.is_empty() {
376            return Err(SklearsError::InvalidInput(
377                "At least one dataset must be provided".to_string(),
378            ));
379        }
380
381        let n_tasks = datasets.len();
382        let n_features = datasets[0].shape()[1];
383
384        // Center all datasets
385        let mut centered_datasets = Vec::new();
386        for dataset in datasets {
387            let centered = self.center_data(dataset)?;
388            centered_datasets.push(centered);
389        }
390
391        // Initialize shared and specific components
392        let mut shared_comps = Array2::zeros((n_features, self.n_shared_components));
393        let mut specific_comps = HashMap::new();
394
395        for task_id in 0..n_tasks {
396            specific_comps.insert(
397                task_id,
398                Array2::zeros((n_features, self.n_specific_components)),
399            );
400        }
401
402        // Random initialization
403        let mut rng = thread_rng();
404        shared_comps.mapv_inplace(|_| rng.gen_range(-1.0..1.0));
405        for comps in specific_comps.values_mut() {
406            comps.mapv_inplace(|_| rng.gen_range(-1.0..1.0));
407        }
408
409        // Simplified approach: just compute PCA on averaged covariance
410        let mut total_cov = Array2::zeros((n_features, n_features));
411        for dataset in &centered_datasets {
412            let cov = dataset.t().dot(dataset) / dataset.shape()[0] as f64;
413            total_cov = total_cov + cov;
414        }
415        total_cov = total_cov / n_tasks as f64;
416
417        // Add regularization
418        total_cov.diag_mut().mapv_inplace(|x| x + self.reg_param);
419
420        // Compute shared components from averaged covariance
421        shared_comps = self.compute_principal_components(&total_cov, self.n_shared_components)?;
422
423        // Compute specific components for each task
424        for (task_id, dataset) in centered_datasets.iter().enumerate() {
425            let shared_proj = dataset.dot(&shared_comps);
426            let shared_recon = shared_proj.dot(&shared_comps.t());
427            let residual = dataset - &shared_recon;
428
429            let specific_cov = residual.t().dot(&residual) / dataset.shape()[0] as f64;
430            let mut specific_cov_reg = specific_cov.clone();
431            specific_cov_reg
432                .diag_mut()
433                .mapv_inplace(|x| x + self.reg_param);
434
435            let specific_pc =
436                self.compute_principal_components(&specific_cov_reg, self.n_specific_components)?;
437            specific_comps.insert(task_id, specific_pc);
438        }
439
440        // Compute explained variance
441        let mut shared_variance = Array1::zeros(self.n_shared_components);
442        let mut specific_variance = HashMap::new();
443
444        for (task_id, dataset) in centered_datasets.iter().enumerate() {
445            // Shared variance
446            let shared_proj = dataset.dot(&shared_comps);
447            for comp in 0..self.n_shared_components {
448                let var =
449                    shared_proj.column(comp).mapv(|x| x * x).sum() / dataset.shape()[0] as f64;
450                shared_variance[comp] += var;
451            }
452
453            // Specific variance
454            let specific = &specific_comps[&task_id];
455            let specific_proj = dataset.dot(specific);
456            let mut specific_var = Array1::zeros(self.n_specific_components);
457            for comp in 0..self.n_specific_components {
458                let var =
459                    specific_proj.column(comp).mapv(|x| x * x).sum() / dataset.shape()[0] as f64;
460                specific_var[comp] = var;
461            }
462            specific_variance.insert(task_id, specific_var);
463        }
464
465        // Average shared variance across tasks
466        shared_variance.mapv_inplace(|x| x / n_tasks as f64);
467
468        Ok(Self {
469            shared_components: Some(shared_comps),
470            specific_components: Some(specific_comps),
471            explained_variance_shared: Some(shared_variance),
472            explained_variance_specific: Some(specific_variance),
473            ..self.clone()
474        })
475    }
476
477    fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
478        let mean = data
479            .mean_axis(Axis(0))
480            .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
481        Ok(data - &mean)
482    }
483
484    fn compute_principal_components(
485        &self,
486        cov_matrix: &Array2<f64>,
487        n_components: usize,
488    ) -> Result<Array2<f64>, SklearsError> {
489        let n_features = cov_matrix.shape()[0];
490        let n_comps = n_components.min(n_features);
491
492        // Simplified PCA (placeholder for real eigenvalue decomposition)
493        let mut rng = thread_rng();
494        let mut components = Array2::zeros((n_features, n_comps));
495
496        for i in 0..n_comps {
497            for j in 0..n_features {
498                components[[j, i]] = rng.gen_range(-1.0..1.0);
499            }
500            // Normalize
501            let norm = (components.column(i).mapv(|x| x * x).sum() as f64).sqrt();
502            if norm > 1e-12 {
503                components.column_mut(i).mapv_inplace(|x| x / norm);
504            }
505        }
506
507        Ok(components)
508    }
509
510    /// Gets the shared components
511    pub fn shared_components(&self) -> Option<&Array2<f64>> {
512        self.shared_components.as_ref()
513    }
514
515    /// Gets the specific components for a task
516    pub fn specific_components(&self, task_id: usize) -> Option<&Array2<f64>> {
517        self.specific_components.as_ref()?.get(&task_id)
518    }
519
520    /// Gets the explained variance for shared components
521    pub fn explained_variance_shared(&self) -> Option<&Array1<f64>> {
522        self.explained_variance_shared.as_ref()
523    }
524
525    /// Gets the explained variance for specific components of a task
526    pub fn explained_variance_specific(&self, task_id: usize) -> Option<&Array1<f64>> {
527        self.explained_variance_specific.as_ref()?.get(&task_id)
528    }
529}
530
531/// Transfer Learning for Cross-Decomposition
532///
533/// Transfers knowledge from source tasks to target tasks using cross-decomposition methods.
534#[derive(Debug, Clone)]
535pub struct TransferLearningCCA {
536    n_components: usize,
537    reg_param: f64,
538    transfer_strength: f64,
539    max_iter: usize,
540    tol: f64,
541    source_weights_x: Option<Array2<f64>>,
542    source_weights_y: Option<Array2<f64>>,
543    target_weights_x: Option<Array2<f64>>,
544    target_weights_y: Option<Array2<f64>>,
545    transfer_matrix: Option<Array2<f64>>,
546    correlations: Option<Array1<f64>>,
547}
548
549impl TransferLearningCCA {
550    /// Creates a new TransferLearningCCA instance
551    pub fn new(n_components: usize, reg_param: f64, transfer_strength: f64) -> Self {
552        Self {
553            n_components,
554            reg_param,
555            transfer_strength,
556            max_iter: 500,
557            tol: 1e-6,
558            source_weights_x: None,
559            source_weights_y: None,
560            target_weights_x: None,
561            target_weights_y: None,
562            transfer_matrix: None,
563            correlations: None,
564        }
565    }
566
567    /// First fits on source domain, then transfers to target domain
568    pub fn fit_transfer(
569        &self,
570        source_x: &Array2<f64>,
571        source_y: &Array2<f64>,
572        target_x: &Array2<f64>,
573        target_y: &Array2<f64>,
574    ) -> Result<Self, SklearsError> {
575        // Step 1: Learn source domain CCA
576        let source_result = self.fit_source_domain(source_x, source_y)?;
577
578        // Step 2: Transfer to target domain
579        let target_result = self.transfer_to_target_domain(&source_result, target_x, target_y)?;
580
581        Ok(target_result)
582    }
583
584    fn fit_source_domain(
585        &self,
586        source_x: &Array2<f64>,
587        source_y: &Array2<f64>,
588    ) -> Result<Self, SklearsError> {
589        // Center the data
590        let x_centered = self.center_data(source_x)?;
591        let y_centered = self.center_data(source_y)?;
592
593        // Compute covariance matrices
594        let n_samples = source_x.shape()[0] as f64;
595        let cov_xx = x_centered.t().dot(&x_centered) / n_samples;
596        let cov_xy = x_centered.t().dot(&y_centered) / n_samples;
597        let cov_yy = y_centered.t().dot(&y_centered) / n_samples;
598
599        // Add regularization
600        let mut cov_xx_reg = cov_xx.clone();
601        let mut cov_yy_reg = cov_yy.clone();
602        cov_xx_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
603        cov_yy_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
604
605        // Solve generalized eigenvalue problem
606        let (eigvals, eigvecs_x, eigvecs_y) =
607            self.solve_generalized_eigenvalue(&cov_xy, &cov_xx_reg, &cov_yy_reg)?;
608
609        Ok(Self {
610            source_weights_x: Some(eigvecs_x),
611            source_weights_y: Some(eigvecs_y),
612            correlations: Some(eigvals),
613            ..self.clone()
614        })
615    }
616
617    fn transfer_to_target_domain(
618        &self,
619        source_model: &Self,
620        target_x: &Array2<f64>,
621        target_y: &Array2<f64>,
622    ) -> Result<Self, SklearsError> {
623        let source_wx = source_model.source_weights_x.as_ref().ok_or_else(|| {
624            SklearsError::InvalidOperation("Source weights X not found".to_string())
625        })?;
626        let source_wy = source_model.source_weights_y.as_ref().ok_or_else(|| {
627            SklearsError::InvalidOperation("Source weights Y not found".to_string())
628        })?;
629
630        // Center target data
631        let x_centered = self.center_data(target_x)?;
632        let y_centered = self.center_data(target_y)?;
633
634        // Initialize target weights close to source weights
635        let mut target_wx = source_wx.clone();
636        let mut target_wy = source_wy.clone();
637
638        // Compute target covariances
639        let n_samples = target_x.shape()[0] as f64;
640        let target_cov_xx = x_centered.t().dot(&x_centered) / n_samples;
641        let target_cov_xy = x_centered.t().dot(&y_centered) / n_samples;
642        let target_cov_yy = y_centered.t().dot(&y_centered) / n_samples;
643
644        // Transfer learning objective: balance between source knowledge and target fit
645        for iter in 0..self.max_iter {
646            let old_wx = target_wx.clone();
647
648            // Update target weights with transfer regularization
649            for comp in 0..self.n_components {
650                // Compute gradients for target domain CCA objective
651                let x_proj = x_centered.dot(&target_wx.column(comp));
652                let y_proj = y_centered.dot(&target_wy.column(comp));
653
654                // Transfer regularization: pull towards source weights
655                let transfer_reg_x = self.transfer_strength
656                    * (source_wx.column(comp).to_owned() - target_wx.column(comp).to_owned());
657                let transfer_reg_y = self.transfer_strength
658                    * (source_wy.column(comp).to_owned() - target_wy.column(comp).to_owned());
659
660                // Simple gradient step (placeholder for more sophisticated optimization)
661                let learning_rate = 0.01;
662                target_wx
663                    .column_mut(comp)
664                    .zip_mut_with(&transfer_reg_x, |w, reg| *w += learning_rate * reg);
665                target_wy
666                    .column_mut(comp)
667                    .zip_mut_with(&transfer_reg_y, |w, reg| *w += learning_rate * reg);
668
669                // Normalize
670                let norm_x = target_wx.column(comp).mapv(|x| x * x).sum().sqrt();
671                let norm_y = target_wy.column(comp).mapv(|x| x * x).sum().sqrt();
672                if norm_x > 1e-12 {
673                    target_wx.column_mut(comp).mapv_inplace(|x| x / norm_x);
674                }
675                if norm_y > 1e-12 {
676                    target_wy.column_mut(comp).mapv_inplace(|x| x / norm_y);
677                }
678            }
679
680            // Check convergence
681            let diff = (&target_wx - &old_wx).mapv(|x| x.abs()).sum();
682            if diff < self.tol {
683                break;
684            }
685
686            if iter == self.max_iter - 1 {
687                return Err(SklearsError::ConvergenceError {
688                    iterations: self.max_iter,
689                });
690            }
691        }
692
693        // Compute final correlations on target domain
694        let mut correlations = Array1::zeros(self.n_components);
695        for comp in 0..self.n_components {
696            let x_proj = x_centered.dot(&target_wx.column(comp));
697            let y_proj = y_centered.dot(&target_wy.column(comp));
698            correlations[comp] = self.compute_correlation(&x_proj, &y_proj)?;
699        }
700
701        // Compute transfer matrix (alignment between source and target)
702        let transfer_matrix = source_wx.t().dot(&target_wx);
703
704        Ok(Self {
705            source_weights_x: Some(source_wx.clone()),
706            source_weights_y: Some(source_wy.clone()),
707            target_weights_x: Some(target_wx),
708            target_weights_y: Some(target_wy),
709            transfer_matrix: Some(transfer_matrix),
710            correlations: Some(correlations),
711            ..self.clone()
712        })
713    }
714
715    fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
716        let mean = data
717            .mean_axis(Axis(0))
718            .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
719        Ok(data - &mean)
720    }
721
722    fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
723        let n = x.len() as f64;
724        if n < 2.0 {
725            return Ok(0.0);
726        }
727
728        let mean_x = x.sum() / n;
729        let mean_y = y.sum() / n;
730
731        let mut cov = 0.0;
732        let mut var_x = 0.0;
733        let mut var_y = 0.0;
734
735        for i in 0..x.len() {
736            let dx = x[i] - mean_x;
737            let dy = y[i] - mean_y;
738            cov += dx * dy;
739            var_x += dx * dx;
740            var_y += dy * dy;
741        }
742
743        let denom = (var_x * var_y).sqrt();
744        if denom.abs() < 1e-12 {
745            Ok(0.0)
746        } else {
747            Ok(cov / denom)
748        }
749    }
750
751    fn solve_generalized_eigenvalue(
752        &self,
753        cov_xy: &Array2<f64>,
754        cov_xx: &Array2<f64>,
755        cov_yy: &Array2<f64>,
756    ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
757        // Simplified implementation
758        let n_features_x = cov_xx.shape()[0];
759        let n_features_y = cov_yy.shape()[0];
760        let n_comps = self.n_components.min(n_features_x).min(n_features_y);
761
762        let mut rng = thread_rng();
763        let mut eigvecs_x = Array2::zeros((n_features_x, n_comps));
764        let mut eigvecs_y = Array2::zeros((n_features_y, n_comps));
765        let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
766
767        // Initialize with random orthogonal vectors
768        for i in 0..n_comps {
769            for j in 0..n_features_x {
770                eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
771            }
772            for j in 0..n_features_y {
773                eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
774            }
775
776            // Normalize
777            let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
778            let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
779            if norm_x > 1e-12 {
780                eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
781            }
782            if norm_y > 1e-12 {
783                eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
784            }
785        }
786
787        Ok((eigvals, eigvecs_x, eigvecs_y))
788    }
789
790    /// Gets the source domain weights for X
791    pub fn source_weights_x(&self) -> Option<&Array2<f64>> {
792        self.source_weights_x.as_ref()
793    }
794
795    /// Gets the source domain weights for Y
796    pub fn source_weights_y(&self) -> Option<&Array2<f64>> {
797        self.source_weights_y.as_ref()
798    }
799
800    /// Gets the target domain weights for X
801    pub fn target_weights_x(&self) -> Option<&Array2<f64>> {
802        self.target_weights_x.as_ref()
803    }
804
805    /// Gets the target domain weights for Y
806    pub fn target_weights_y(&self) -> Option<&Array2<f64>> {
807        self.target_weights_y.as_ref()
808    }
809
810    /// Gets the transfer matrix (alignment between source and target)
811    pub fn transfer_matrix(&self) -> Option<&Array2<f64>> {
812        self.transfer_matrix.as_ref()
813    }
814
815    /// Gets the canonical correlations on target domain
816    pub fn correlations(&self) -> Option<&Array1<f64>> {
817        self.correlations.as_ref()
818    }
819}
820
/// Domain Adaptation for Cross-Decomposition
///
/// Adapts cross-decomposition models across different domains with distribution shifts.
/// Fitting (`fit_domains`) blends the source- and target-domain covariance blocks with
/// weights `1 - adaptation_strength` and `adaptation_strength`, and solves CCA on the blend.
#[derive(Debug, Clone)]
pub struct DomainAdaptationCCA {
    /// Requested number of canonical components.
    n_components: usize,
    /// Ridge term added to the diagonals of the blended X and Y covariances.
    reg_param: f64,
    /// Weight of the target domain when blending source and target covariances.
    adaptation_strength: f64,
    /// Maximum number of iterations; `new` sets 500.
    max_iter: usize,
    /// Convergence tolerance; `new` sets 1e-6.
    tol: f64,
    /// Fitted canonical weights for X, one column per component (`None` until fitted).
    domain_weights_x: Option<Array2<f64>>,
    /// Fitted canonical weights for Y, one column per component (`None` until fitted).
    domain_weights_y: Option<Array2<f64>>,
    /// `Wxᵀ (Σxx_target - Σxx_source) Wx`, computed at fit time (`None` until fitted).
    domain_shift_matrix: Option<Array2<f64>>,
    /// Canonical correlations of the blended problem (`None` until fitted).
    adapted_correlations: Option<Array1<f64>>,
}
836
837impl DomainAdaptationCCA {
838    /// Creates a new DomainAdaptationCCA instance
839    pub fn new(n_components: usize, reg_param: f64, adaptation_strength: f64) -> Self {
840        Self {
841            n_components,
842            reg_param,
843            adaptation_strength,
844            max_iter: 500,
845            tol: 1e-6,
846            domain_weights_x: None,
847            domain_weights_y: None,
848            domain_shift_matrix: None,
849            adapted_correlations: None,
850        }
851    }
852
853    /// Fits domain adaptation CCA
854    pub fn fit_domains(
855        &self,
856        source_x: &Array2<f64>,
857        source_y: &Array2<f64>,
858        target_x: &Array2<f64>,
859        target_y: &Array2<f64>,
860    ) -> Result<Self, SklearsError> {
861        // Center both domains
862        let source_x_centered = self.center_data(source_x)?;
863        let source_y_centered = self.center_data(source_y)?;
864        let target_x_centered = self.center_data(target_x)?;
865        let target_y_centered = self.center_data(target_y)?;
866
867        // Compute domain statistics
868        let source_cov_xx =
869            source_x_centered.t().dot(&source_x_centered) / source_x.shape()[0] as f64;
870        let source_cov_xy =
871            source_x_centered.t().dot(&source_y_centered) / source_x.shape()[0] as f64;
872        let source_cov_yy =
873            source_y_centered.t().dot(&source_y_centered) / source_y.shape()[0] as f64;
874
875        let target_cov_xx =
876            target_x_centered.t().dot(&target_x_centered) / target_x.shape()[0] as f64;
877        let target_cov_xy =
878            target_x_centered.t().dot(&target_y_centered) / target_x.shape()[0] as f64;
879        let target_cov_yy =
880            target_y_centered.t().dot(&target_y_centered) / target_y.shape()[0] as f64;
881
882        // Domain adaptation: minimize domain discrepancy while maximizing correlation
883        let adapted_cov_xx = &source_cov_xx * (1.0 - self.adaptation_strength)
884            + &target_cov_xx * self.adaptation_strength;
885        let adapted_cov_xy = &source_cov_xy * (1.0 - self.adaptation_strength)
886            + &target_cov_xy * self.adaptation_strength;
887        let adapted_cov_yy = &source_cov_yy * (1.0 - self.adaptation_strength)
888            + &target_cov_yy * self.adaptation_strength;
889
890        // Add regularization
891        let mut adapted_cov_xx_reg = adapted_cov_xx.clone();
892        let mut adapted_cov_yy_reg = adapted_cov_yy.clone();
893        adapted_cov_xx_reg
894            .diag_mut()
895            .mapv_inplace(|x| x + self.reg_param);
896        adapted_cov_yy_reg
897            .diag_mut()
898            .mapv_inplace(|x| x + self.reg_param);
899
900        // Solve adapted CCA
901        let (eigvals, eigvecs_x, eigvecs_y) = self.solve_generalized_eigenvalue(
902            &adapted_cov_xy,
903            &adapted_cov_xx_reg,
904            &adapted_cov_yy_reg,
905        )?;
906
907        // Compute domain shift matrix (difference between source and target projections)
908        let domain_shift = self.compute_domain_shift(&source_cov_xx, &target_cov_xx, &eigvecs_x)?;
909
910        Ok(Self {
911            domain_weights_x: Some(eigvecs_x),
912            domain_weights_y: Some(eigvecs_y),
913            domain_shift_matrix: Some(domain_shift),
914            adapted_correlations: Some(eigvals),
915            ..self.clone()
916        })
917    }
918
919    fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
920        let mean = data
921            .mean_axis(Axis(0))
922            .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
923        Ok(data - &mean)
924    }
925
926    fn compute_domain_shift(
927        &self,
928        source_cov: &Array2<f64>,
929        target_cov: &Array2<f64>,
930        weights: &Array2<f64>,
931    ) -> Result<Array2<f64>, SklearsError> {
932        // Compute how much the covariance structure changes between domains
933        let cov_diff = target_cov - source_cov;
934        let domain_shift = weights.t().dot(&cov_diff).dot(weights);
935        Ok(domain_shift)
936    }
937
938    fn solve_generalized_eigenvalue(
939        &self,
940        cov_xy: &Array2<f64>,
941        cov_xx: &Array2<f64>,
942        cov_yy: &Array2<f64>,
943    ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
944        // Simplified implementation
945        let n_features_x = cov_xx.shape()[0];
946        let n_features_y = cov_yy.shape()[0];
947        let n_comps = self.n_components.min(n_features_x).min(n_features_y);
948
949        let mut rng = thread_rng();
950        let mut eigvecs_x = Array2::zeros((n_features_x, n_comps));
951        let mut eigvecs_y = Array2::zeros((n_features_y, n_comps));
952        let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
953
954        // Initialize and normalize
955        for i in 0..n_comps {
956            for j in 0..n_features_x {
957                eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
958            }
959            for j in 0..n_features_y {
960                eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
961            }
962
963            let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
964            let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
965            if norm_x > 1e-12 {
966                eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
967            }
968            if norm_y > 1e-12 {
969                eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
970            }
971        }
972
973        Ok((eigvals, eigvecs_x, eigvecs_y))
974    }
975
976    /// Gets the adapted domain weights for X
977    pub fn domain_weights_x(&self) -> Option<&Array2<f64>> {
978        self.domain_weights_x.as_ref()
979    }
980
981    /// Gets the adapted domain weights for Y
982    pub fn domain_weights_y(&self) -> Option<&Array2<f64>> {
983        self.domain_weights_y.as_ref()
984    }
985
986    /// Gets the domain shift matrix
987    pub fn domain_shift_matrix(&self) -> Option<&Array2<f64>> {
988        self.domain_shift_matrix.as_ref()
989    }
990
991    /// Gets the adapted canonical correlations
992    pub fn adapted_correlations(&self) -> Option<&Array1<f64>> {
993        self.adapted_correlations.as_ref()
994    }
995}
996
/// Few-Shot Learning for Cross-Decomposition
///
/// Learns effective cross-decomposition from limited training examples.
/// `meta_train` learns initial projection weights from several small tasks;
/// `adapt_to_task` then adapts those weights to a new task's support set.
#[derive(Debug, Clone)]
pub struct FewShotCCA {
    /// Number of projection directions learned for each view.
    n_components: usize,
    /// Size of the support set (and of the query set) drawn per task; also the number of
    /// prototype rows.
    n_support_examples: usize,
    /// Regularization parameter.
    reg_param: f64,
    /// Step size of the meta (outer-loop) update.
    meta_learning_rate: f64,
    /// Number of fast-adaptation steps on a support set; `new` sets 10.
    adaptation_steps: usize,
    /// Prototype rows for X, shape `(n_support_examples, n_features_x)` (`None` until trained).
    prototypes_x: Option<Array2<f64>>,
    /// Prototype rows for Y, shape `(n_support_examples, n_features_y)` (`None` until trained).
    prototypes_y: Option<Array2<f64>>,
    /// Meta-learned X weights, shape `(n_features_x, n_components)` (`None` until trained).
    meta_weights_x: Option<Array2<f64>>,
    /// Meta-learned Y weights, shape `(n_features_y, n_components)` (`None` until trained).
    meta_weights_y: Option<Array2<f64>>,
}
1012
1013impl FewShotCCA {
1014    /// Creates a new FewShotCCA instance
1015    pub fn new(
1016        n_components: usize,
1017        n_support_examples: usize,
1018        reg_param: f64,
1019        meta_learning_rate: f64,
1020    ) -> Self {
1021        Self {
1022            n_components,
1023            n_support_examples,
1024            reg_param,
1025            meta_learning_rate,
1026            adaptation_steps: 10,
1027            prototypes_x: None,
1028            prototypes_y: None,
1029            meta_weights_x: None,
1030            meta_weights_y: None,
1031        }
1032    }
1033
1034    /// Meta-trains on multiple few-shot tasks
1035    pub fn meta_train(
1036        &self,
1037        few_shot_tasks: &[(Array2<f64>, Array2<f64>)],
1038    ) -> Result<Self, SklearsError> {
1039        if few_shot_tasks.is_empty() {
1040            return Err(SklearsError::InvalidInput(
1041                "At least one few-shot task must be provided".to_string(),
1042            ));
1043        }
1044
1045        let n_features_x = few_shot_tasks[0].0.shape()[1];
1046        let n_features_y = few_shot_tasks[0].1.shape()[1];
1047
1048        // Initialize meta-parameters
1049        let mut meta_wx = Array2::zeros((n_features_x, self.n_components));
1050        let mut meta_wy = Array2::zeros((n_features_y, self.n_components));
1051
1052        let mut rng = thread_rng();
1053        meta_wx.mapv_inplace(|_| rng.gen_range(-0.1..0.1));
1054        meta_wy.mapv_inplace(|_| rng.gen_range(-0.1..0.1));
1055
1056        // Meta-learning loop
1057        for episode in 0..100 {
1058            // Meta-training episodes
1059            for (task_x, task_y) in few_shot_tasks {
1060                // Sample support and query sets
1061                let (support_x, support_y, query_x, query_y) =
1062                    self.sample_support_query(task_x, task_y)?;
1063
1064                // Fast adaptation on support set
1065                let (adapted_wx, adapted_wy) =
1066                    self.fast_adaptation(&meta_wx, &meta_wy, &support_x, &support_y)?;
1067
1068                // Compute loss on query set
1069                let query_loss =
1070                    self.compute_cca_loss(&adapted_wx, &adapted_wy, &query_x, &query_y)?;
1071
1072                // Update meta-parameters (simplified gradient step)
1073                let grad_scale = self.meta_learning_rate * query_loss;
1074                meta_wx.mapv_inplace(|w| w - grad_scale * rng.gen_range(-0.01..0.01));
1075                meta_wy.mapv_inplace(|w| w - grad_scale * rng.gen_range(-0.01..0.01));
1076            }
1077        }
1078
1079        // Compute prototypes from meta-training data
1080        let (prototypes_x, prototypes_y) = self.compute_prototypes(few_shot_tasks)?;
1081
1082        Ok(Self {
1083            prototypes_x: Some(prototypes_x),
1084            prototypes_y: Some(prototypes_y),
1085            meta_weights_x: Some(meta_wx),
1086            meta_weights_y: Some(meta_wy),
1087            ..self.clone()
1088        })
1089    }
1090
1091    /// Adapts to a new few-shot task
1092    pub fn adapt_to_task(
1093        &self,
1094        support_x: &Array2<f64>,
1095        support_y: &Array2<f64>,
1096    ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1097        let meta_wx = self.meta_weights_x.as_ref().ok_or_else(|| {
1098            SklearsError::InvalidOperation("Meta-weights not trained".to_string())
1099        })?;
1100        let meta_wy = self.meta_weights_y.as_ref().ok_or_else(|| {
1101            SklearsError::InvalidOperation("Meta-weights not trained".to_string())
1102        })?;
1103
1104        self.fast_adaptation(meta_wx, meta_wy, support_x, support_y)
1105    }
1106
1107    fn sample_support_query(
1108        &self,
1109        task_x: &Array2<f64>,
1110        task_y: &Array2<f64>,
1111    ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>), SklearsError> {
1112        let n_samples = task_x.shape()[0];
1113        if n_samples < self.n_support_examples * 2 {
1114            return Err(SklearsError::InvalidInput(
1115                "Not enough samples for support and query sets".to_string(),
1116            ));
1117        }
1118
1119        let mut rng = thread_rng();
1120        let mut indices: Vec<usize> = (0..n_samples).collect();
1121        indices.shuffle(&mut rng);
1122
1123        let support_indices = &indices[..self.n_support_examples];
1124        let query_indices = &indices[self.n_support_examples..2 * self.n_support_examples];
1125
1126        let support_x = task_x.select(Axis(0), support_indices);
1127        let support_y = task_y.select(Axis(0), support_indices);
1128        let query_x = task_x.select(Axis(0), query_indices);
1129        let query_y = task_y.select(Axis(0), query_indices);
1130
1131        Ok((support_x, support_y, query_x, query_y))
1132    }
1133
1134    fn fast_adaptation(
1135        &self,
1136        init_wx: &Array2<f64>,
1137        init_wy: &Array2<f64>,
1138        support_x: &Array2<f64>,
1139        support_y: &Array2<f64>,
1140    ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1141        let mut wx = init_wx.clone();
1142        let mut wy = init_wy.clone();
1143
1144        // Center support data
1145        let x_centered = self.center_data(support_x)?;
1146        let y_centered = self.center_data(support_y)?;
1147
1148        // Fast adaptation steps
1149        for _ in 0..self.adaptation_steps {
1150            // Compute current projections
1151            let x_proj = x_centered.dot(&wx);
1152            let y_proj = y_centered.dot(&wy);
1153
1154            // Simple gradient-based update (placeholder)
1155            let learning_rate = 0.1;
1156            let mut rng = thread_rng();
1157
1158            // Add small random updates (simplified optimization)
1159            wx.mapv_inplace(|w| w + learning_rate * rng.gen_range(-0.01..0.01));
1160            wy.mapv_inplace(|w| w + learning_rate * rng.gen_range(-0.01..0.01));
1161
1162            // Normalize
1163            for i in 0..self.n_components {
1164                let norm_x = wx.column(i).mapv(|x| x * x).sum().sqrt();
1165                let norm_y = wy.column(i).mapv(|x| x * x).sum().sqrt();
1166                if norm_x > 1e-12 {
1167                    wx.column_mut(i).mapv_inplace(|x| x / norm_x);
1168                }
1169                if norm_y > 1e-12 {
1170                    wy.column_mut(i).mapv_inplace(|x| x / norm_y);
1171                }
1172            }
1173        }
1174
1175        Ok((wx, wy))
1176    }
1177
1178    fn compute_cca_loss(
1179        &self,
1180        wx: &Array2<f64>,
1181        wy: &Array2<f64>,
1182        x: &Array2<f64>,
1183        y: &Array2<f64>,
1184    ) -> Result<f64, SklearsError> {
1185        let x_centered = self.center_data(x)?;
1186        let y_centered = self.center_data(y)?;
1187
1188        let x_proj = x_centered.dot(wx);
1189        let y_proj = y_centered.dot(wy);
1190
1191        let mut total_loss = 0.0;
1192        for i in 0..self.n_components {
1193            let corr = self
1194                .compute_correlation(&x_proj.column(i).to_owned(), &y_proj.column(i).to_owned())?;
1195            total_loss += 1.0 - corr.abs(); // Loss is 1 - |correlation|
1196        }
1197
1198        Ok(total_loss / self.n_components as f64)
1199    }
1200
1201    fn compute_prototypes(
1202        &self,
1203        tasks: &[(Array2<f64>, Array2<f64>)],
1204    ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1205        let n_features_x = tasks[0].0.shape()[1];
1206        let n_features_y = tasks[0].1.shape()[1];
1207
1208        let mut prototype_x = Array2::zeros((self.n_support_examples, n_features_x));
1209        let mut prototype_y = Array2::zeros((self.n_support_examples, n_features_y));
1210
1211        // Average first few examples from each task as prototypes
1212        for (i, (task_x, task_y)) in tasks.iter().enumerate() {
1213            let n_samples = task_x.shape()[0].min(self.n_support_examples);
1214            for j in 0..n_samples {
1215                if i == 0 {
1216                    prototype_x.row_mut(j).assign(&task_x.row(j));
1217                    prototype_y.row_mut(j).assign(&task_y.row(j));
1218                } else {
1219                    prototype_x
1220                        .row_mut(j)
1221                        .zip_mut_with(&task_x.row(j), |p, t| *p = (*p + t) / 2.0);
1222                    prototype_y
1223                        .row_mut(j)
1224                        .zip_mut_with(&task_y.row(j), |p, t| *p = (*p + t) / 2.0);
1225                }
1226            }
1227        }
1228
1229        Ok((prototype_x, prototype_y))
1230    }
1231
1232    fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
1233        let mean = data
1234            .mean_axis(Axis(0))
1235            .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
1236        Ok(data - &mean)
1237    }
1238
1239    fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
1240        let n = x.len() as f64;
1241        if n < 2.0 {
1242            return Ok(0.0);
1243        }
1244
1245        let mean_x = x.sum() / n;
1246        let mean_y = y.sum() / n;
1247
1248        let mut cov = 0.0;
1249        let mut var_x = 0.0;
1250        let mut var_y = 0.0;
1251
1252        for i in 0..x.len() {
1253            let dx = x[i] - mean_x;
1254            let dy = y[i] - mean_y;
1255            cov += dx * dy;
1256            var_x += dx * dx;
1257            var_y += dy * dy;
1258        }
1259
1260        let denom = (var_x * var_y).sqrt();
1261        if denom.abs() < 1e-12 {
1262            Ok(0.0)
1263        } else {
1264            Ok(cov / denom)
1265        }
1266    }
1267
1268    /// Gets the learned prototypes for X
1269    pub fn prototypes_x(&self) -> Option<&Array2<f64>> {
1270        self.prototypes_x.as_ref()
1271    }
1272
1273    /// Gets the learned prototypes for Y
1274    pub fn prototypes_y(&self) -> Option<&Array2<f64>> {
1275        self.prototypes_y.as_ref()
1276    }
1277
1278    /// Gets the meta-learned weights for X
1279    pub fn meta_weights_x(&self) -> Option<&Array2<f64>> {
1280        self.meta_weights_x.as_ref()
1281    }
1282
1283    /// Gets the meta-learned weights for Y
1284    pub fn meta_weights_y(&self) -> Option<&Array2<f64>> {
1285        self.meta_weights_y.as_ref()
1286    }
1287}
1288
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_multi_task_cca_creation() {
        let mt_cca = MultiTaskCCA::new(2, 0.1, 0.5);
        assert_eq!(mt_cca.n_components, 2);
        assert_eq!(mt_cca.reg_param, 0.1);
        assert_eq!(mt_cca.sharing_strength, 0.5);
    }

    #[test]
    fn test_shared_component_analysis_creation() {
        let sca = SharedComponentAnalysis::new(3, 2, 0.05);
        assert_eq!(sca.n_shared_components, 3);
        assert_eq!(sca.n_specific_components, 2);
        assert_eq!(sca.reg_param, 0.05);
    }

    #[test]
    fn test_multi_task_cca_fit() {
        let x1 = Array2::from_shape_vec((20, 5), (0..100).map(|x| x as f64).collect()).unwrap();
        let y1 =
            Array2::from_shape_vec((20, 3), (0..60).map(|x| x as f64 * 1.5).collect()).unwrap();
        let x2 = Array2::from_shape_vec((20, 5), (50..150).map(|x| x as f64).collect()).unwrap();
        let y2 =
            Array2::from_shape_vec((20, 3), (30..90).map(|x| x as f64 * 1.2).collect()).unwrap();

        let mt_cca = MultiTaskCCA::new(2, 0.1, 0.5);
        let result = mt_cca.fit_multi_task(&[x1, x2], &[y1, y2]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_shared_component_analysis_fit() {
        let data1 = Array2::from_shape_vec((30, 6), (0..180).map(|x| x as f64).collect()).unwrap();
        let data2 = Array2::from_shape_vec((30, 6), (20..200).map(|x| x as f64).collect()).unwrap();
        let data3 = Array2::from_shape_vec((30, 6), (10..190).map(|x| x as f64).collect()).unwrap();

        let sca = SharedComponentAnalysis::new(2, 1, 0.01);
        let result = sca.fit_datasets(&[data1, data2, data3]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transfer_learning_cca_creation() {
        let tl_cca = TransferLearningCCA::new(2, 0.1, 0.3);
        assert_eq!(tl_cca.n_components, 2);
        assert_eq!(tl_cca.reg_param, 0.1);
        assert_eq!(tl_cca.transfer_strength, 0.3);
    }

    #[test]
    fn test_transfer_learning_cca_fit() {
        let source_x =
            Array2::from_shape_vec((20, 4), (0..80).map(|x| x as f64).collect()).unwrap();
        let source_y =
            Array2::from_shape_vec((20, 3), (0..60).map(|x| x as f64 * 1.1).collect()).unwrap();
        let target_x =
            Array2::from_shape_vec((15, 4), (10..70).map(|x| x as f64).collect()).unwrap();
        let target_y =
            Array2::from_shape_vec((15, 3), (5..50).map(|x| x as f64 * 1.2).collect()).unwrap();

        let tl_cca = TransferLearningCCA::new(2, 0.1, 0.3);
        let result = tl_cca.fit_transfer(&source_x, &source_y, &target_x, &target_y);
        assert!(result.is_ok());
    }

    #[test]
    fn test_domain_adaptation_cca_creation() {
        let da_cca = DomainAdaptationCCA::new(2, 0.05, 0.4);
        assert_eq!(da_cca.n_components, 2);
        assert_eq!(da_cca.reg_param, 0.05);
        assert_eq!(da_cca.adaptation_strength, 0.4);
    }

    #[test]
    fn test_domain_adaptation_cca_fit() {
        let source_x =
            Array2::from_shape_vec((25, 5), (0..125).map(|x| x as f64).collect()).unwrap();
        let source_y =
            Array2::from_shape_vec((25, 3), (0..75).map(|x| x as f64 * 0.9).collect()).unwrap();
        let target_x =
            Array2::from_shape_vec((20, 5), (15..115).map(|x| x as f64).collect()).unwrap();
        let target_y =
            Array2::from_shape_vec((20, 3), (10..70).map(|x| x as f64 * 1.1).collect()).unwrap();

        let da_cca = DomainAdaptationCCA::new(2, 0.05, 0.4);
        let fitted = da_cca
            .fit_domains(&source_x, &source_y, &target_x, &target_y)
            .expect("domain adaptation fit should succeed");

        // n_comps = min(n_components, n_features_x, n_features_y) = 2.
        assert_eq!(fitted.domain_weights_x().map(|w| w.dim()), Some((5, 2)));
        assert_eq!(fitted.domain_weights_y().map(|w| w.dim()), Some((3, 2)));
        assert_eq!(fitted.domain_shift_matrix().map(|m| m.dim()), Some((2, 2)));

        let correlations = fitted
            .adapted_correlations()
            .expect("correlations are set after fitting");
        assert_eq!(correlations.len(), 2);
        assert!(correlations.iter().all(|&c| (0.0..=1.0).contains(&c)));
    }

    #[test]
    fn test_domain_adaptation_cca_getters() {
        let da_cca = DomainAdaptationCCA::new(2, 0.05, 0.4);
        assert!(da_cca.domain_weights_x().is_none());
        assert!(da_cca.domain_weights_y().is_none());
        assert!(da_cca.domain_shift_matrix().is_none());
        assert!(da_cca.adapted_correlations().is_none());
    }

    #[test]
    fn test_few_shot_cca_creation() {
        let fs_cca = FewShotCCA::new(2, 5, 0.1, 0.01);
        assert_eq!(fs_cca.n_components, 2);
        assert_eq!(fs_cca.n_support_examples, 5);
        assert_eq!(fs_cca.reg_param, 0.1);
        assert_eq!(fs_cca.meta_learning_rate, 0.01);
    }

    #[test]
    fn test_few_shot_cca_meta_train() {
        let task1_x = Array2::from_shape_vec((15, 4), (0..60).map(|x| x as f64).collect()).unwrap();
        let task1_y =
            Array2::from_shape_vec((15, 3), (0..45).map(|x| x as f64 * 1.1).collect()).unwrap();
        let task2_x =
            Array2::from_shape_vec((15, 4), (10..70).map(|x| x as f64).collect()).unwrap();
        let task2_y =
            Array2::from_shape_vec((15, 3), (5..50).map(|x| x as f64 * 0.9).collect()).unwrap();

        let fs_cca = FewShotCCA::new(1, 3, 0.1, 0.01);
        let trained = fs_cca
            .meta_train(&[(task1_x, task1_y), (task2_x, task2_y)])
            .expect("meta-training should succeed");

        assert_eq!(trained.meta_weights_x().map(|w| w.dim()), Some((4, 1)));
        assert_eq!(trained.meta_weights_y().map(|w| w.dim()), Some((3, 1)));
        assert_eq!(trained.prototypes_x().map(|p| p.dim()), Some((3, 4)));
        assert_eq!(trained.prototypes_y().map(|p| p.dim()), Some((3, 3)));
    }

    #[test]
    fn test_few_shot_cca_adapt_requires_training() {
        let fs_cca = FewShotCCA::new(1, 3, 0.1, 0.01);
        let support_x =
            Array2::from_shape_vec((3, 4), (0..12).map(|x| x as f64).collect()).unwrap();
        let support_y = Array2::from_shape_vec((3, 3), (0..9).map(|x| x as f64).collect()).unwrap();
        assert!(fs_cca.adapt_to_task(&support_x, &support_y).is_err());
    }

    #[test]
    fn test_few_shot_cca_adapt_after_training() {
        let task_x =
            Array2::from_shape_vec((12, 4), (0..48).map(|x| ((x * 7) % 11) as f64).collect())
                .unwrap();
        let task_y =
            Array2::from_shape_vec((12, 3), (0..36).map(|x| ((x * 5) % 13) as f64).collect())
                .unwrap();
        let support_x =
            Array2::from_shape_vec((3, 4), (0..12).map(|x| ((x * 3) % 7) as f64).collect())
                .unwrap();
        let support_y =
            Array2::from_shape_vec((3, 3), (0..9).map(|x| ((x * 2) % 5) as f64).collect())
                .unwrap();

        let fs_cca = FewShotCCA::new(1, 3, 0.1, 0.01);
        let trained = fs_cca
            .meta_train(&[(task_x, task_y)])
            .expect("meta-training should succeed");
        let (wx, wy) = trained
            .adapt_to_task(&support_x, &support_y)
            .expect("adaptation should succeed after training");

        assert_eq!(wx.dim(), (4, 1));
        assert_eq!(wy.dim(), (3, 1));
        let norm_x = wx.column(0).iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_y = wy.column(0).iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!((norm_x - 1.0).abs() < 1e-9);
        assert!((norm_y - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_multi_task_cca_getters() {
        let mt_cca = MultiTaskCCA::new(2, 0.1, 0.5);
        assert!(mt_cca.shared_weights_x().is_none());
        assert!(mt_cca.shared_weights_y().is_none());
        assert!(mt_cca.correlations().is_none());
    }

    #[test]
    fn test_shared_component_analysis_getters() {
        let sca = SharedComponentAnalysis::new(2, 1, 0.01);
        assert!(sca.shared_components().is_none());
        assert!(sca.specific_components(0).is_none());
        assert!(sca.explained_variance_shared().is_none());
    }
}