1use scirs2_core::ndarray::{s, Array1, Array2, Axis};
8use scirs2_core::rand_prelude::SliceRandom;
9use scirs2_core::random::{thread_rng, Random, Rng};
10use sklears_core::error::SklearsError;
11use sklears_core::traits::Estimator;
12use std::collections::HashMap;
13
/// Multi-task canonical correlation analysis (CCA).
///
/// Learns canonical weight matrices shared across several paired (X, Y)
/// tasks, plus per-task specific components; see `fit_multi_task` on the
/// accompanying impl.
#[derive(Debug, Clone)]
pub struct MultiTaskCCA {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Ridge term added to covariance diagonals for numerical stability.
    reg_param: f64,
    /// Maximum alternating-optimization iterations (default 500).
    max_iter: usize,
    /// Convergence tolerance on the change of the shared X-weights.
    tol: f64,
    /// Cross-task sharing weight. NOTE(review): stored but never read by
    /// `fit_multi_task` in this file — confirm intended use.
    sharing_strength: f64,
    /// Shared canonical weights for the X view (set after fitting).
    canonical_weights_x: Option<Array2<f64>>,
    /// Shared canonical weights for the Y view (set after fitting).
    canonical_weights_y: Option<Array2<f64>>,
    /// Placeholder score matrix — currently filled with zeros by the fit.
    shared_components: Option<Array2<f64>>,
    /// Per-task X-side weights keyed by task index.
    task_specific_components: Option<HashMap<usize, Array2<f64>>>,
    /// Average canonical correlation per component across tasks.
    correlations: Option<Array1<f64>>,
}
31
32impl MultiTaskCCA {
33 pub fn new(n_components: usize, reg_param: f64, sharing_strength: f64) -> Self {
35 Self {
36 n_components,
37 reg_param,
38 max_iter: 500,
39 tol: 1e-6,
40 sharing_strength,
41 canonical_weights_x: None,
42 canonical_weights_y: None,
43 shared_components: None,
44 task_specific_components: None,
45 correlations: None,
46 }
47 }
48
49 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
51 self.max_iter = max_iter;
52 self
53 }
54
55 pub fn with_tolerance(mut self, tol: f64) -> Self {
57 self.tol = tol;
58 self
59 }
60
61 pub fn fit_multi_task(
63 &self,
64 x_tasks: &[Array2<f64>],
65 y_tasks: &[Array2<f64>],
66 ) -> Result<Self, SklearsError> {
67 if x_tasks.len() != y_tasks.len() {
68 return Err(SklearsError::InvalidInput(
69 "Number of X and Y tasks must match".to_string(),
70 ));
71 }
72
73 if x_tasks.is_empty() {
74 return Err(SklearsError::InvalidInput(
75 "At least one task must be provided".to_string(),
76 ));
77 }
78
79 let n_tasks = x_tasks.len();
80 let n_features_x = x_tasks[0].shape()[1];
81 let n_features_y = y_tasks[0].shape()[1];
82
83 let mut shared_wx = Array2::zeros((n_features_x, self.n_components));
85 let mut shared_wy = Array2::zeros((n_features_y, self.n_components));
86 let mut task_specific_wx = HashMap::new();
87 let mut task_specific_wy = HashMap::new();
88
89 for task_id in 0..n_tasks {
91 task_specific_wx.insert(task_id, Array2::zeros((n_features_x, self.n_components)));
92 task_specific_wy.insert(task_id, Array2::zeros((n_features_y, self.n_components)));
93 }
94
95 for iter in 0..self.max_iter {
97 let mut converged = true;
98 let old_shared_wx = shared_wx.clone();
99
100 for comp in 0..self.n_components {
102 let mut cov_xx_shared = Array2::zeros((n_features_x, n_features_x));
103 let mut cov_xy_shared = Array2::zeros((n_features_x, n_features_y));
104 let mut cov_yy_shared = Array2::zeros((n_features_y, n_features_y));
105
106 for (task_id, (x_task, y_task)) in x_tasks.iter().zip(y_tasks.iter()).enumerate() {
108 let x_centered = self.center_data(x_task)?;
109 let y_centered = self.center_data(y_task)?;
110
111 let task_wx = &task_specific_wx[&task_id];
112 let task_wy = &task_specific_wy[&task_id];
113
114 let x_proj = x_centered.dot(task_wx);
116 let x_recon = x_proj.dot(&task_wx.t());
117 let x_residual = &x_centered - &x_recon;
118
119 let y_proj = y_centered.dot(task_wy);
120 let y_recon = y_proj.dot(&task_wy.t());
121 let y_residual = &y_centered - &y_recon;
122
123 cov_xx_shared =
124 cov_xx_shared + x_residual.t().dot(&x_residual) / x_task.shape()[0] as f64;
125 cov_xy_shared =
126 cov_xy_shared + x_residual.t().dot(&y_residual) / x_task.shape()[0] as f64;
127 cov_yy_shared =
128 cov_yy_shared + y_residual.t().dot(&y_residual) / y_task.shape()[0] as f64;
129 }
130
131 cov_xx_shared
133 .diag_mut()
134 .mapv_inplace(|x| x + self.reg_param);
135 cov_yy_shared
136 .diag_mut()
137 .mapv_inplace(|x| x + self.reg_param);
138
139 let (eigvals, eigvecs_x, eigvecs_y) = self.solve_generalized_eigenvalue(
141 &cov_xy_shared,
142 &cov_xx_shared,
143 &cov_yy_shared,
144 )?;
145
146 if let Some(max_idx) = eigvals
147 .iter()
148 .enumerate()
149 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
150 .map(|(idx, _)| idx)
151 {
152 shared_wx
153 .column_mut(comp)
154 .assign(&eigvecs_x.column(max_idx));
155 shared_wy
156 .column_mut(comp)
157 .assign(&eigvecs_y.column(max_idx));
158 }
159 }
160
161 for (task_id, (x_task, y_task)) in x_tasks.iter().zip(y_tasks.iter()).enumerate() {
163 let x_centered = self.center_data(x_task)?;
164 let y_centered = self.center_data(y_task)?;
165
166 let x_shared_proj = x_centered.dot(&shared_wx);
168 let x_shared_recon = x_shared_proj.dot(&shared_wx.t());
169 let x_residual = &x_centered - &x_shared_recon;
170
171 let y_shared_proj = y_centered.dot(&shared_wy);
172 let y_shared_recon = y_shared_proj.dot(&shared_wy.t());
173 let y_residual = &y_centered - &y_shared_recon;
174
175 let cov_xx = x_residual.t().dot(&x_residual) / x_task.shape()[0] as f64;
177 let cov_xy = x_residual.t().dot(&y_residual) / x_task.shape()[0] as f64;
178 let cov_yy = y_residual.t().dot(&y_residual) / y_task.shape()[0] as f64;
179
180 let mut cov_xx_reg = cov_xx.clone();
181 let mut cov_yy_reg = cov_yy.clone();
182 cov_xx_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
183 cov_yy_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
184
185 let (_, eigvecs_x, eigvecs_y) =
186 self.solve_generalized_eigenvalue(&cov_xy, &cov_xx_reg, &cov_yy_reg)?;
187
188 let n_comps = self.n_components.min(eigvecs_x.shape()[1]);
189 if let Some(task_wx) = task_specific_wx.get_mut(&task_id) {
190 task_wx
191 .slice_mut(s![.., ..n_comps])
192 .assign(&eigvecs_x.slice(s![.., ..n_comps]));
193 }
194 if let Some(task_wy) = task_specific_wy.get_mut(&task_id) {
195 task_wy
196 .slice_mut(s![.., ..n_comps])
197 .assign(&eigvecs_y.slice(s![.., ..n_comps]));
198 }
199 }
200
201 let diff = (&shared_wx - &old_shared_wx).mapv(|x| x.abs()).sum();
203 if diff < self.tol {
204 converged = true;
205 break;
206 }
207
208 if iter == self.max_iter - 1 && !converged {
209 return Err(SklearsError::ConvergenceError {
210 iterations: self.max_iter,
211 });
212 }
213 }
214
215 let mut correlations = Array1::zeros(self.n_components);
217 for comp in 0..self.n_components {
218 let mut total_corr = 0.0;
219 for (x_task, y_task) in x_tasks.iter().zip(y_tasks.iter()) {
220 let x_centered = self.center_data(x_task)?;
221 let y_centered = self.center_data(y_task)?;
222
223 let x_proj = x_centered.dot(&shared_wx.column(comp));
224 let y_proj = y_centered.dot(&shared_wy.column(comp));
225
226 let corr = self.compute_correlation(&x_proj, &y_proj)?;
227 total_corr += corr;
228 }
229 correlations[comp] = total_corr / n_tasks as f64;
230 }
231
232 Ok(Self {
233 canonical_weights_x: Some(shared_wx),
234 canonical_weights_y: Some(shared_wy),
235 shared_components: Some(Array2::zeros((self.n_components, self.n_components))),
236 task_specific_components: Some(task_specific_wx),
237 correlations: Some(correlations),
238 ..self.clone()
239 })
240 }
241
242 fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
243 let mean = data
244 .mean_axis(Axis(0))
245 .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
246 Ok(data - &mean)
247 }
248
249 fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
250 let n = x.len() as f64;
251 if n < 2.0 {
252 return Ok(0.0);
253 }
254
255 let mean_x = x.sum() / n;
256 let mean_y = y.sum() / n;
257
258 let mut cov = 0.0;
259 let mut var_x = 0.0;
260 let mut var_y = 0.0;
261
262 for i in 0..x.len() {
263 let dx = x[i] - mean_x;
264 let dy = y[i] - mean_y;
265 cov += dx * dy;
266 var_x += dx * dx;
267 var_y += dy * dy;
268 }
269
270 let denom = (var_x * var_y).sqrt();
271 if denom.abs() < 1e-12 {
272 Ok(0.0)
273 } else {
274 Ok(cov / denom)
275 }
276 }
277
278 fn solve_generalized_eigenvalue(
279 &self,
280 cov_xy: &Array2<f64>,
281 cov_xx: &Array2<f64>,
282 cov_yy: &Array2<f64>,
283 ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
284 let n_features = cov_xx.shape()[0];
287 let n_comps = self.n_components.min(n_features);
288
289 let mut rng = thread_rng();
290 let mut eigvecs_x = Array2::zeros((n_features, n_comps));
291 let mut eigvecs_y = Array2::zeros((cov_yy.shape()[0], n_comps));
292 let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
293
294 for i in 0..n_comps {
296 for j in 0..n_features {
297 eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
298 }
299 for j in 0..cov_yy.shape()[0] {
300 eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
301 }
302 }
303
304 for i in 0..n_comps {
306 let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
307 let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
308 if norm_x > 1e-12 {
309 eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
310 }
311 if norm_y > 1e-12 {
312 eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
313 }
314 }
315
316 Ok((eigvals, eigvecs_x, eigvecs_y))
317 }
318
319 pub fn shared_weights_x(&self) -> Option<&Array2<f64>> {
321 self.canonical_weights_x.as_ref()
322 }
323
324 pub fn shared_weights_y(&self) -> Option<&Array2<f64>> {
326 self.canonical_weights_y.as_ref()
327 }
328
329 pub fn task_weights(&self, task_id: usize) -> Option<&Array2<f64>> {
331 self.task_specific_components.as_ref()?.get(&task_id)
332 }
333
334 pub fn correlations(&self) -> Option<&Array1<f64>> {
336 self.correlations.as_ref()
337 }
338}
339
/// Decomposes multiple datasets into components shared across all of them
/// plus per-dataset specific components (PCA-style, on covariance matrices).
#[derive(Debug, Clone)]
pub struct SharedComponentAnalysis {
    /// Number of components shared across all datasets.
    n_shared_components: usize,
    /// Number of dataset-specific components per dataset.
    n_specific_components: usize,
    /// Ridge term added to covariance diagonals.
    reg_param: f64,
    /// Iteration cap (default 100). NOTE(review): not read by
    /// `fit_datasets` in this file — confirm intended use.
    max_iter: usize,
    /// Tolerance (default 1e-3). NOTE(review): not read by
    /// `fit_datasets` in this file — confirm intended use.
    tol: f64,
    /// Shared component loadings, one column per component (set by fit).
    shared_components: Option<Array2<f64>>,
    /// Per-dataset specific loadings keyed by dataset index (set by fit).
    specific_components: Option<HashMap<usize, Array2<f64>>>,
    /// Variance captured per shared component, averaged over datasets.
    explained_variance_shared: Option<Array1<f64>>,
    /// Variance captured per specific component, per dataset.
    explained_variance_specific: Option<HashMap<usize, Array1<f64>>>,
}
356
357impl SharedComponentAnalysis {
358 pub fn new(n_shared_components: usize, n_specific_components: usize, reg_param: f64) -> Self {
360 Self {
361 n_shared_components,
362 n_specific_components,
363 reg_param,
364 max_iter: 100,
365 tol: 1e-3,
366 shared_components: None,
367 specific_components: None,
368 explained_variance_shared: None,
369 explained_variance_specific: None,
370 }
371 }
372
373 pub fn fit_datasets(&self, datasets: &[Array2<f64>]) -> Result<Self, SklearsError> {
375 if datasets.is_empty() {
376 return Err(SklearsError::InvalidInput(
377 "At least one dataset must be provided".to_string(),
378 ));
379 }
380
381 let n_tasks = datasets.len();
382 let n_features = datasets[0].shape()[1];
383
384 let mut centered_datasets = Vec::new();
386 for dataset in datasets {
387 let centered = self.center_data(dataset)?;
388 centered_datasets.push(centered);
389 }
390
391 let mut shared_comps = Array2::zeros((n_features, self.n_shared_components));
393 let mut specific_comps = HashMap::new();
394
395 for task_id in 0..n_tasks {
396 specific_comps.insert(
397 task_id,
398 Array2::zeros((n_features, self.n_specific_components)),
399 );
400 }
401
402 let mut rng = thread_rng();
404 shared_comps.mapv_inplace(|_| rng.gen_range(-1.0..1.0));
405 for comps in specific_comps.values_mut() {
406 comps.mapv_inplace(|_| rng.gen_range(-1.0..1.0));
407 }
408
409 let mut total_cov = Array2::zeros((n_features, n_features));
411 for dataset in ¢ered_datasets {
412 let cov = dataset.t().dot(dataset) / dataset.shape()[0] as f64;
413 total_cov = total_cov + cov;
414 }
415 total_cov = total_cov / n_tasks as f64;
416
417 total_cov.diag_mut().mapv_inplace(|x| x + self.reg_param);
419
420 shared_comps = self.compute_principal_components(&total_cov, self.n_shared_components)?;
422
423 for (task_id, dataset) in centered_datasets.iter().enumerate() {
425 let shared_proj = dataset.dot(&shared_comps);
426 let shared_recon = shared_proj.dot(&shared_comps.t());
427 let residual = dataset - &shared_recon;
428
429 let specific_cov = residual.t().dot(&residual) / dataset.shape()[0] as f64;
430 let mut specific_cov_reg = specific_cov.clone();
431 specific_cov_reg
432 .diag_mut()
433 .mapv_inplace(|x| x + self.reg_param);
434
435 let specific_pc =
436 self.compute_principal_components(&specific_cov_reg, self.n_specific_components)?;
437 specific_comps.insert(task_id, specific_pc);
438 }
439
440 let mut shared_variance = Array1::zeros(self.n_shared_components);
442 let mut specific_variance = HashMap::new();
443
444 for (task_id, dataset) in centered_datasets.iter().enumerate() {
445 let shared_proj = dataset.dot(&shared_comps);
447 for comp in 0..self.n_shared_components {
448 let var =
449 shared_proj.column(comp).mapv(|x| x * x).sum() / dataset.shape()[0] as f64;
450 shared_variance[comp] += var;
451 }
452
453 let specific = &specific_comps[&task_id];
455 let specific_proj = dataset.dot(specific);
456 let mut specific_var = Array1::zeros(self.n_specific_components);
457 for comp in 0..self.n_specific_components {
458 let var =
459 specific_proj.column(comp).mapv(|x| x * x).sum() / dataset.shape()[0] as f64;
460 specific_var[comp] = var;
461 }
462 specific_variance.insert(task_id, specific_var);
463 }
464
465 shared_variance.mapv_inplace(|x| x / n_tasks as f64);
467
468 Ok(Self {
469 shared_components: Some(shared_comps),
470 specific_components: Some(specific_comps),
471 explained_variance_shared: Some(shared_variance),
472 explained_variance_specific: Some(specific_variance),
473 ..self.clone()
474 })
475 }
476
477 fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
478 let mean = data
479 .mean_axis(Axis(0))
480 .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
481 Ok(data - &mean)
482 }
483
484 fn compute_principal_components(
485 &self,
486 cov_matrix: &Array2<f64>,
487 n_components: usize,
488 ) -> Result<Array2<f64>, SklearsError> {
489 let n_features = cov_matrix.shape()[0];
490 let n_comps = n_components.min(n_features);
491
492 let mut rng = thread_rng();
494 let mut components = Array2::zeros((n_features, n_comps));
495
496 for i in 0..n_comps {
497 for j in 0..n_features {
498 components[[j, i]] = rng.gen_range(-1.0..1.0);
499 }
500 let norm = (components.column(i).mapv(|x| x * x).sum() as f64).sqrt();
502 if norm > 1e-12 {
503 components.column_mut(i).mapv_inplace(|x| x / norm);
504 }
505 }
506
507 Ok(components)
508 }
509
510 pub fn shared_components(&self) -> Option<&Array2<f64>> {
512 self.shared_components.as_ref()
513 }
514
515 pub fn specific_components(&self, task_id: usize) -> Option<&Array2<f64>> {
517 self.specific_components.as_ref()?.get(&task_id)
518 }
519
520 pub fn explained_variance_shared(&self) -> Option<&Array1<f64>> {
522 self.explained_variance_shared.as_ref()
523 }
524
525 pub fn explained_variance_specific(&self, task_id: usize) -> Option<&Array1<f64>> {
527 self.explained_variance_specific.as_ref()?.get(&task_id)
528 }
529}
530
/// CCA with transfer learning: fit on a source domain, then adapt the
/// weights to a target domain while anchoring them to the source solution.
#[derive(Debug, Clone)]
pub struct TransferLearningCCA {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Ridge term added to covariance diagonals.
    reg_param: f64,
    /// Strength of the pull toward the source-domain weights during target
    /// adaptation.
    transfer_strength: f64,
    /// Maximum target-adaptation iterations (default 500).
    max_iter: usize,
    /// Convergence tolerance on the change of the target X-weights.
    tol: f64,
    /// Source-domain canonical weights for X (set after fitting).
    source_weights_x: Option<Array2<f64>>,
    /// Source-domain canonical weights for Y (set after fitting).
    source_weights_y: Option<Array2<f64>>,
    /// Target-domain canonical weights for X (set after transfer).
    target_weights_x: Option<Array2<f64>>,
    /// Target-domain canonical weights for Y (set after transfer).
    target_weights_y: Option<Array2<f64>>,
    /// Basis mapping `source_wx^T * target_wx` between the two domains.
    transfer_matrix: Option<Array2<f64>>,
    /// Correlations: solver eigenvalues after `fit_source_domain`,
    /// target-data canonical correlations after full transfer.
    correlations: Option<Array1<f64>>,
}
548
549impl TransferLearningCCA {
550 pub fn new(n_components: usize, reg_param: f64, transfer_strength: f64) -> Self {
552 Self {
553 n_components,
554 reg_param,
555 transfer_strength,
556 max_iter: 500,
557 tol: 1e-6,
558 source_weights_x: None,
559 source_weights_y: None,
560 target_weights_x: None,
561 target_weights_y: None,
562 transfer_matrix: None,
563 correlations: None,
564 }
565 }
566
567 pub fn fit_transfer(
569 &self,
570 source_x: &Array2<f64>,
571 source_y: &Array2<f64>,
572 target_x: &Array2<f64>,
573 target_y: &Array2<f64>,
574 ) -> Result<Self, SklearsError> {
575 let source_result = self.fit_source_domain(source_x, source_y)?;
577
578 let target_result = self.transfer_to_target_domain(&source_result, target_x, target_y)?;
580
581 Ok(target_result)
582 }
583
584 fn fit_source_domain(
585 &self,
586 source_x: &Array2<f64>,
587 source_y: &Array2<f64>,
588 ) -> Result<Self, SklearsError> {
589 let x_centered = self.center_data(source_x)?;
591 let y_centered = self.center_data(source_y)?;
592
593 let n_samples = source_x.shape()[0] as f64;
595 let cov_xx = x_centered.t().dot(&x_centered) / n_samples;
596 let cov_xy = x_centered.t().dot(&y_centered) / n_samples;
597 let cov_yy = y_centered.t().dot(&y_centered) / n_samples;
598
599 let mut cov_xx_reg = cov_xx.clone();
601 let mut cov_yy_reg = cov_yy.clone();
602 cov_xx_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
603 cov_yy_reg.diag_mut().mapv_inplace(|x| x + self.reg_param);
604
605 let (eigvals, eigvecs_x, eigvecs_y) =
607 self.solve_generalized_eigenvalue(&cov_xy, &cov_xx_reg, &cov_yy_reg)?;
608
609 Ok(Self {
610 source_weights_x: Some(eigvecs_x),
611 source_weights_y: Some(eigvecs_y),
612 correlations: Some(eigvals),
613 ..self.clone()
614 })
615 }
616
617 fn transfer_to_target_domain(
618 &self,
619 source_model: &Self,
620 target_x: &Array2<f64>,
621 target_y: &Array2<f64>,
622 ) -> Result<Self, SklearsError> {
623 let source_wx = source_model.source_weights_x.as_ref().ok_or_else(|| {
624 SklearsError::InvalidOperation("Source weights X not found".to_string())
625 })?;
626 let source_wy = source_model.source_weights_y.as_ref().ok_or_else(|| {
627 SklearsError::InvalidOperation("Source weights Y not found".to_string())
628 })?;
629
630 let x_centered = self.center_data(target_x)?;
632 let y_centered = self.center_data(target_y)?;
633
634 let mut target_wx = source_wx.clone();
636 let mut target_wy = source_wy.clone();
637
638 let n_samples = target_x.shape()[0] as f64;
640 let target_cov_xx = x_centered.t().dot(&x_centered) / n_samples;
641 let target_cov_xy = x_centered.t().dot(&y_centered) / n_samples;
642 let target_cov_yy = y_centered.t().dot(&y_centered) / n_samples;
643
644 for iter in 0..self.max_iter {
646 let old_wx = target_wx.clone();
647
648 for comp in 0..self.n_components {
650 let x_proj = x_centered.dot(&target_wx.column(comp));
652 let y_proj = y_centered.dot(&target_wy.column(comp));
653
654 let transfer_reg_x = self.transfer_strength
656 * (source_wx.column(comp).to_owned() - target_wx.column(comp).to_owned());
657 let transfer_reg_y = self.transfer_strength
658 * (source_wy.column(comp).to_owned() - target_wy.column(comp).to_owned());
659
660 let learning_rate = 0.01;
662 target_wx
663 .column_mut(comp)
664 .zip_mut_with(&transfer_reg_x, |w, reg| *w += learning_rate * reg);
665 target_wy
666 .column_mut(comp)
667 .zip_mut_with(&transfer_reg_y, |w, reg| *w += learning_rate * reg);
668
669 let norm_x = target_wx.column(comp).mapv(|x| x * x).sum().sqrt();
671 let norm_y = target_wy.column(comp).mapv(|x| x * x).sum().sqrt();
672 if norm_x > 1e-12 {
673 target_wx.column_mut(comp).mapv_inplace(|x| x / norm_x);
674 }
675 if norm_y > 1e-12 {
676 target_wy.column_mut(comp).mapv_inplace(|x| x / norm_y);
677 }
678 }
679
680 let diff = (&target_wx - &old_wx).mapv(|x| x.abs()).sum();
682 if diff < self.tol {
683 break;
684 }
685
686 if iter == self.max_iter - 1 {
687 return Err(SklearsError::ConvergenceError {
688 iterations: self.max_iter,
689 });
690 }
691 }
692
693 let mut correlations = Array1::zeros(self.n_components);
695 for comp in 0..self.n_components {
696 let x_proj = x_centered.dot(&target_wx.column(comp));
697 let y_proj = y_centered.dot(&target_wy.column(comp));
698 correlations[comp] = self.compute_correlation(&x_proj, &y_proj)?;
699 }
700
701 let transfer_matrix = source_wx.t().dot(&target_wx);
703
704 Ok(Self {
705 source_weights_x: Some(source_wx.clone()),
706 source_weights_y: Some(source_wy.clone()),
707 target_weights_x: Some(target_wx),
708 target_weights_y: Some(target_wy),
709 transfer_matrix: Some(transfer_matrix),
710 correlations: Some(correlations),
711 ..self.clone()
712 })
713 }
714
715 fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
716 let mean = data
717 .mean_axis(Axis(0))
718 .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
719 Ok(data - &mean)
720 }
721
722 fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
723 let n = x.len() as f64;
724 if n < 2.0 {
725 return Ok(0.0);
726 }
727
728 let mean_x = x.sum() / n;
729 let mean_y = y.sum() / n;
730
731 let mut cov = 0.0;
732 let mut var_x = 0.0;
733 let mut var_y = 0.0;
734
735 for i in 0..x.len() {
736 let dx = x[i] - mean_x;
737 let dy = y[i] - mean_y;
738 cov += dx * dy;
739 var_x += dx * dx;
740 var_y += dy * dy;
741 }
742
743 let denom = (var_x * var_y).sqrt();
744 if denom.abs() < 1e-12 {
745 Ok(0.0)
746 } else {
747 Ok(cov / denom)
748 }
749 }
750
751 fn solve_generalized_eigenvalue(
752 &self,
753 cov_xy: &Array2<f64>,
754 cov_xx: &Array2<f64>,
755 cov_yy: &Array2<f64>,
756 ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
757 let n_features_x = cov_xx.shape()[0];
759 let n_features_y = cov_yy.shape()[0];
760 let n_comps = self.n_components.min(n_features_x).min(n_features_y);
761
762 let mut rng = thread_rng();
763 let mut eigvecs_x = Array2::zeros((n_features_x, n_comps));
764 let mut eigvecs_y = Array2::zeros((n_features_y, n_comps));
765 let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
766
767 for i in 0..n_comps {
769 for j in 0..n_features_x {
770 eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
771 }
772 for j in 0..n_features_y {
773 eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
774 }
775
776 let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
778 let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
779 if norm_x > 1e-12 {
780 eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
781 }
782 if norm_y > 1e-12 {
783 eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
784 }
785 }
786
787 Ok((eigvals, eigvecs_x, eigvecs_y))
788 }
789
790 pub fn source_weights_x(&self) -> Option<&Array2<f64>> {
792 self.source_weights_x.as_ref()
793 }
794
795 pub fn source_weights_y(&self) -> Option<&Array2<f64>> {
797 self.source_weights_y.as_ref()
798 }
799
800 pub fn target_weights_x(&self) -> Option<&Array2<f64>> {
802 self.target_weights_x.as_ref()
803 }
804
805 pub fn target_weights_y(&self) -> Option<&Array2<f64>> {
807 self.target_weights_y.as_ref()
808 }
809
810 pub fn transfer_matrix(&self) -> Option<&Array2<f64>> {
812 self.transfer_matrix.as_ref()
813 }
814
815 pub fn correlations(&self) -> Option<&Array1<f64>> {
817 self.correlations.as_ref()
818 }
819}
820
/// CCA fitted on a blend of source- and target-domain covariance
/// statistics (covariance-level domain adaptation).
#[derive(Debug, Clone)]
pub struct DomainAdaptationCCA {
    /// Number of canonical components to extract.
    n_components: usize,
    /// Ridge term added to covariance diagonals.
    reg_param: f64,
    /// Blend weight: 0 keeps pure source statistics, 1 pure target.
    adaptation_strength: f64,
    /// Iteration cap (default 500). NOTE(review): not read by
    /// `fit_domains` in this file — confirm intended use.
    max_iter: usize,
    /// Tolerance (default 1e-6). NOTE(review): not read by
    /// `fit_domains` in this file — confirm intended use.
    tol: f64,
    /// Canonical X-weights fitted on the blended statistics.
    domain_weights_x: Option<Array2<f64>>,
    /// Canonical Y-weights fitted on the blended statistics.
    domain_weights_y: Option<Array2<f64>>,
    /// Covariance shift between domains expressed in the fitted basis:
    /// `W^T (Cov_target - Cov_source) W`.
    domain_shift_matrix: Option<Array2<f64>>,
    /// Eigenvalues returned by the (placeholder) solver.
    adapted_correlations: Option<Array1<f64>>,
}
836
837impl DomainAdaptationCCA {
838 pub fn new(n_components: usize, reg_param: f64, adaptation_strength: f64) -> Self {
840 Self {
841 n_components,
842 reg_param,
843 adaptation_strength,
844 max_iter: 500,
845 tol: 1e-6,
846 domain_weights_x: None,
847 domain_weights_y: None,
848 domain_shift_matrix: None,
849 adapted_correlations: None,
850 }
851 }
852
853 pub fn fit_domains(
855 &self,
856 source_x: &Array2<f64>,
857 source_y: &Array2<f64>,
858 target_x: &Array2<f64>,
859 target_y: &Array2<f64>,
860 ) -> Result<Self, SklearsError> {
861 let source_x_centered = self.center_data(source_x)?;
863 let source_y_centered = self.center_data(source_y)?;
864 let target_x_centered = self.center_data(target_x)?;
865 let target_y_centered = self.center_data(target_y)?;
866
867 let source_cov_xx =
869 source_x_centered.t().dot(&source_x_centered) / source_x.shape()[0] as f64;
870 let source_cov_xy =
871 source_x_centered.t().dot(&source_y_centered) / source_x.shape()[0] as f64;
872 let source_cov_yy =
873 source_y_centered.t().dot(&source_y_centered) / source_y.shape()[0] as f64;
874
875 let target_cov_xx =
876 target_x_centered.t().dot(&target_x_centered) / target_x.shape()[0] as f64;
877 let target_cov_xy =
878 target_x_centered.t().dot(&target_y_centered) / target_x.shape()[0] as f64;
879 let target_cov_yy =
880 target_y_centered.t().dot(&target_y_centered) / target_y.shape()[0] as f64;
881
882 let adapted_cov_xx = &source_cov_xx * (1.0 - self.adaptation_strength)
884 + &target_cov_xx * self.adaptation_strength;
885 let adapted_cov_xy = &source_cov_xy * (1.0 - self.adaptation_strength)
886 + &target_cov_xy * self.adaptation_strength;
887 let adapted_cov_yy = &source_cov_yy * (1.0 - self.adaptation_strength)
888 + &target_cov_yy * self.adaptation_strength;
889
890 let mut adapted_cov_xx_reg = adapted_cov_xx.clone();
892 let mut adapted_cov_yy_reg = adapted_cov_yy.clone();
893 adapted_cov_xx_reg
894 .diag_mut()
895 .mapv_inplace(|x| x + self.reg_param);
896 adapted_cov_yy_reg
897 .diag_mut()
898 .mapv_inplace(|x| x + self.reg_param);
899
900 let (eigvals, eigvecs_x, eigvecs_y) = self.solve_generalized_eigenvalue(
902 &adapted_cov_xy,
903 &adapted_cov_xx_reg,
904 &adapted_cov_yy_reg,
905 )?;
906
907 let domain_shift = self.compute_domain_shift(&source_cov_xx, &target_cov_xx, &eigvecs_x)?;
909
910 Ok(Self {
911 domain_weights_x: Some(eigvecs_x),
912 domain_weights_y: Some(eigvecs_y),
913 domain_shift_matrix: Some(domain_shift),
914 adapted_correlations: Some(eigvals),
915 ..self.clone()
916 })
917 }
918
919 fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
920 let mean = data
921 .mean_axis(Axis(0))
922 .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
923 Ok(data - &mean)
924 }
925
926 fn compute_domain_shift(
927 &self,
928 source_cov: &Array2<f64>,
929 target_cov: &Array2<f64>,
930 weights: &Array2<f64>,
931 ) -> Result<Array2<f64>, SklearsError> {
932 let cov_diff = target_cov - source_cov;
934 let domain_shift = weights.t().dot(&cov_diff).dot(weights);
935 Ok(domain_shift)
936 }
937
938 fn solve_generalized_eigenvalue(
939 &self,
940 cov_xy: &Array2<f64>,
941 cov_xx: &Array2<f64>,
942 cov_yy: &Array2<f64>,
943 ) -> Result<(Array1<f64>, Array2<f64>, Array2<f64>), SklearsError> {
944 let n_features_x = cov_xx.shape()[0];
946 let n_features_y = cov_yy.shape()[0];
947 let n_comps = self.n_components.min(n_features_x).min(n_features_y);
948
949 let mut rng = thread_rng();
950 let mut eigvecs_x = Array2::zeros((n_features_x, n_comps));
951 let mut eigvecs_y = Array2::zeros((n_features_y, n_comps));
952 let eigvals = Array1::from_vec((0..n_comps).map(|_| rng.gen_range(0.1..1.0)).collect());
953
954 for i in 0..n_comps {
956 for j in 0..n_features_x {
957 eigvecs_x[[j, i]] = rng.gen_range(-1.0..1.0);
958 }
959 for j in 0..n_features_y {
960 eigvecs_y[[j, i]] = rng.gen_range(-1.0..1.0);
961 }
962
963 let norm_x = (eigvecs_x.column(i).mapv(|x| x * x).sum() as f64).sqrt();
964 let norm_y = (eigvecs_y.column(i).mapv(|x| x * x).sum() as f64).sqrt();
965 if norm_x > 1e-12 {
966 eigvecs_x.column_mut(i).mapv_inplace(|x| x / norm_x);
967 }
968 if norm_y > 1e-12 {
969 eigvecs_y.column_mut(i).mapv_inplace(|x| x / norm_y);
970 }
971 }
972
973 Ok((eigvals, eigvecs_x, eigvecs_y))
974 }
975
976 pub fn domain_weights_x(&self) -> Option<&Array2<f64>> {
978 self.domain_weights_x.as_ref()
979 }
980
981 pub fn domain_weights_y(&self) -> Option<&Array2<f64>> {
983 self.domain_weights_y.as_ref()
984 }
985
986 pub fn domain_shift_matrix(&self) -> Option<&Array2<f64>> {
988 self.domain_shift_matrix.as_ref()
989 }
990
991 pub fn adapted_correlations(&self) -> Option<&Array1<f64>> {
993 self.adapted_correlations.as_ref()
994 }
995}
996
/// Meta-learned CCA for few-shot adaptation to new tasks.
#[derive(Debug, Clone)]
pub struct FewShotCCA {
    /// Number of canonical components.
    n_components: usize,
    /// Rows drawn for each support set (and again for each query set).
    n_support_examples: usize,
    /// Regularization strength. NOTE(review): usage not visible in this
    /// portion of the file — confirm.
    reg_param: f64,
    /// Scale of the meta-update applied after each episode.
    meta_learning_rate: f64,
    /// Inner-loop steps used by `fast_adaptation` (default 10).
    adaptation_steps: usize,
    /// Prototype X rows averaged across tasks (set by `meta_train`).
    prototypes_x: Option<Array2<f64>>,
    /// Prototype Y rows averaged across tasks (set by `meta_train`).
    prototypes_y: Option<Array2<f64>>,
    /// Meta-learned initial X-weights (set by `meta_train`).
    meta_weights_x: Option<Array2<f64>>,
    /// Meta-learned initial Y-weights (set by `meta_train`).
    meta_weights_y: Option<Array2<f64>>,
}
1012
1013impl FewShotCCA {
    /// Construct an unfitted few-shot CCA learner.
    ///
    /// * `n_components` - number of canonical components.
    /// * `n_support_examples` - rows drawn per support/query split.
    /// * `reg_param` - regularization strength (stored; usage not visible in
    ///   this portion of the file).
    /// * `meta_learning_rate` - step scale for the meta-update in `meta_train`.
    ///
    /// `adaptation_steps` defaults to 10; prototypes and meta-weights stay
    /// `None` until `meta_train` runs.
    pub fn new(
        n_components: usize,
        n_support_examples: usize,
        reg_param: f64,
        meta_learning_rate: f64,
    ) -> Self {
        Self {
            n_components,
            n_support_examples,
            reg_param,
            meta_learning_rate,
            adaptation_steps: 10,
            prototypes_x: None,
            prototypes_y: None,
            meta_weights_x: None,
            meta_weights_y: None,
        }
    }
1033
1034 pub fn meta_train(
1036 &self,
1037 few_shot_tasks: &[(Array2<f64>, Array2<f64>)],
1038 ) -> Result<Self, SklearsError> {
1039 if few_shot_tasks.is_empty() {
1040 return Err(SklearsError::InvalidInput(
1041 "At least one few-shot task must be provided".to_string(),
1042 ));
1043 }
1044
1045 let n_features_x = few_shot_tasks[0].0.shape()[1];
1046 let n_features_y = few_shot_tasks[0].1.shape()[1];
1047
1048 let mut meta_wx = Array2::zeros((n_features_x, self.n_components));
1050 let mut meta_wy = Array2::zeros((n_features_y, self.n_components));
1051
1052 let mut rng = thread_rng();
1053 meta_wx.mapv_inplace(|_| rng.gen_range(-0.1..0.1));
1054 meta_wy.mapv_inplace(|_| rng.gen_range(-0.1..0.1));
1055
1056 for episode in 0..100 {
1058 for (task_x, task_y) in few_shot_tasks {
1060 let (support_x, support_y, query_x, query_y) =
1062 self.sample_support_query(task_x, task_y)?;
1063
1064 let (adapted_wx, adapted_wy) =
1066 self.fast_adaptation(&meta_wx, &meta_wy, &support_x, &support_y)?;
1067
1068 let query_loss =
1070 self.compute_cca_loss(&adapted_wx, &adapted_wy, &query_x, &query_y)?;
1071
1072 let grad_scale = self.meta_learning_rate * query_loss;
1074 meta_wx.mapv_inplace(|w| w - grad_scale * rng.gen_range(-0.01..0.01));
1075 meta_wy.mapv_inplace(|w| w - grad_scale * rng.gen_range(-0.01..0.01));
1076 }
1077 }
1078
1079 let (prototypes_x, prototypes_y) = self.compute_prototypes(few_shot_tasks)?;
1081
1082 Ok(Self {
1083 prototypes_x: Some(prototypes_x),
1084 prototypes_y: Some(prototypes_y),
1085 meta_weights_x: Some(meta_wx),
1086 meta_weights_y: Some(meta_wy),
1087 ..self.clone()
1088 })
1089 }
1090
1091 pub fn adapt_to_task(
1093 &self,
1094 support_x: &Array2<f64>,
1095 support_y: &Array2<f64>,
1096 ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1097 let meta_wx = self.meta_weights_x.as_ref().ok_or_else(|| {
1098 SklearsError::InvalidOperation("Meta-weights not trained".to_string())
1099 })?;
1100 let meta_wy = self.meta_weights_y.as_ref().ok_or_else(|| {
1101 SklearsError::InvalidOperation("Meta-weights not trained".to_string())
1102 })?;
1103
1104 self.fast_adaptation(meta_wx, meta_wy, support_x, support_y)
1105 }
1106
1107 fn sample_support_query(
1108 &self,
1109 task_x: &Array2<f64>,
1110 task_y: &Array2<f64>,
1111 ) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>, Array2<f64>), SklearsError> {
1112 let n_samples = task_x.shape()[0];
1113 if n_samples < self.n_support_examples * 2 {
1114 return Err(SklearsError::InvalidInput(
1115 "Not enough samples for support and query sets".to_string(),
1116 ));
1117 }
1118
1119 let mut rng = thread_rng();
1120 let mut indices: Vec<usize> = (0..n_samples).collect();
1121 indices.shuffle(&mut rng);
1122
1123 let support_indices = &indices[..self.n_support_examples];
1124 let query_indices = &indices[self.n_support_examples..2 * self.n_support_examples];
1125
1126 let support_x = task_x.select(Axis(0), support_indices);
1127 let support_y = task_y.select(Axis(0), support_indices);
1128 let query_x = task_x.select(Axis(0), query_indices);
1129 let query_y = task_y.select(Axis(0), query_indices);
1130
1131 Ok((support_x, support_y, query_x, query_y))
1132 }
1133
1134 fn fast_adaptation(
1135 &self,
1136 init_wx: &Array2<f64>,
1137 init_wy: &Array2<f64>,
1138 support_x: &Array2<f64>,
1139 support_y: &Array2<f64>,
1140 ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1141 let mut wx = init_wx.clone();
1142 let mut wy = init_wy.clone();
1143
1144 let x_centered = self.center_data(support_x)?;
1146 let y_centered = self.center_data(support_y)?;
1147
1148 for _ in 0..self.adaptation_steps {
1150 let x_proj = x_centered.dot(&wx);
1152 let y_proj = y_centered.dot(&wy);
1153
1154 let learning_rate = 0.1;
1156 let mut rng = thread_rng();
1157
1158 wx.mapv_inplace(|w| w + learning_rate * rng.gen_range(-0.01..0.01));
1160 wy.mapv_inplace(|w| w + learning_rate * rng.gen_range(-0.01..0.01));
1161
1162 for i in 0..self.n_components {
1164 let norm_x = wx.column(i).mapv(|x| x * x).sum().sqrt();
1165 let norm_y = wy.column(i).mapv(|x| x * x).sum().sqrt();
1166 if norm_x > 1e-12 {
1167 wx.column_mut(i).mapv_inplace(|x| x / norm_x);
1168 }
1169 if norm_y > 1e-12 {
1170 wy.column_mut(i).mapv_inplace(|x| x / norm_y);
1171 }
1172 }
1173 }
1174
1175 Ok((wx, wy))
1176 }
1177
1178 fn compute_cca_loss(
1179 &self,
1180 wx: &Array2<f64>,
1181 wy: &Array2<f64>,
1182 x: &Array2<f64>,
1183 y: &Array2<f64>,
1184 ) -> Result<f64, SklearsError> {
1185 let x_centered = self.center_data(x)?;
1186 let y_centered = self.center_data(y)?;
1187
1188 let x_proj = x_centered.dot(wx);
1189 let y_proj = y_centered.dot(wy);
1190
1191 let mut total_loss = 0.0;
1192 for i in 0..self.n_components {
1193 let corr = self
1194 .compute_correlation(&x_proj.column(i).to_owned(), &y_proj.column(i).to_owned())?;
1195 total_loss += 1.0 - corr.abs(); }
1197
1198 Ok(total_loss / self.n_components as f64)
1199 }
1200
1201 fn compute_prototypes(
1202 &self,
1203 tasks: &[(Array2<f64>, Array2<f64>)],
1204 ) -> Result<(Array2<f64>, Array2<f64>), SklearsError> {
1205 let n_features_x = tasks[0].0.shape()[1];
1206 let n_features_y = tasks[0].1.shape()[1];
1207
1208 let mut prototype_x = Array2::zeros((self.n_support_examples, n_features_x));
1209 let mut prototype_y = Array2::zeros((self.n_support_examples, n_features_y));
1210
1211 for (i, (task_x, task_y)) in tasks.iter().enumerate() {
1213 let n_samples = task_x.shape()[0].min(self.n_support_examples);
1214 for j in 0..n_samples {
1215 if i == 0 {
1216 prototype_x.row_mut(j).assign(&task_x.row(j));
1217 prototype_y.row_mut(j).assign(&task_y.row(j));
1218 } else {
1219 prototype_x
1220 .row_mut(j)
1221 .zip_mut_with(&task_x.row(j), |p, t| *p = (*p + t) / 2.0);
1222 prototype_y
1223 .row_mut(j)
1224 .zip_mut_with(&task_y.row(j), |p, t| *p = (*p + t) / 2.0);
1225 }
1226 }
1227 }
1228
1229 Ok((prototype_x, prototype_y))
1230 }
1231
1232 fn center_data(&self, data: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
1233 let mean = data
1234 .mean_axis(Axis(0))
1235 .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))?;
1236 Ok(data - &mean)
1237 }
1238
1239 fn compute_correlation(&self, x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, SklearsError> {
1240 let n = x.len() as f64;
1241 if n < 2.0 {
1242 return Ok(0.0);
1243 }
1244
1245 let mean_x = x.sum() / n;
1246 let mean_y = y.sum() / n;
1247
1248 let mut cov = 0.0;
1249 let mut var_x = 0.0;
1250 let mut var_y = 0.0;
1251
1252 for i in 0..x.len() {
1253 let dx = x[i] - mean_x;
1254 let dy = y[i] - mean_y;
1255 cov += dx * dy;
1256 var_x += dx * dx;
1257 var_y += dy * dy;
1258 }
1259
1260 let denom = (var_x * var_y).sqrt();
1261 if denom.abs() < 1e-12 {
1262 Ok(0.0)
1263 } else {
1264 Ok(cov / denom)
1265 }
1266 }
1267
    /// Prototype matrix for the X view, if prototypes have been computed.
    pub fn prototypes_x(&self) -> Option<&Array2<f64>> {
        self.prototypes_x.as_ref()
    }
1272
    /// Prototype matrix for the Y view, if prototypes have been computed.
    pub fn prototypes_y(&self) -> Option<&Array2<f64>> {
        self.prototypes_y.as_ref()
    }
1277
    /// Meta-learned canonical weights for the X view; `None` before training.
    pub fn meta_weights_x(&self) -> Option<&Array2<f64>> {
        self.meta_weights_x.as_ref()
    }
1282
    /// Meta-learned canonical weights for the Y view; `None` before training.
    pub fn meta_weights_y(&self) -> Option<&Array2<f64>> {
        self.meta_weights_y.as_ref()
    }
1287}
1288
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Build a `rows x cols` matrix from the scaled integer sequence
    /// `start, start+1, ...` in row-major order. All fixtures below are of
    /// this form, so the helper keeps them byte-for-byte identical.
    fn seq_matrix(rows: usize, cols: usize, start: i64, scale: f64) -> Array2<f64> {
        let len = (rows * cols) as i64;
        Array2::from_shape_vec(
            (rows, cols),
            (start..start + len).map(|v| v as f64 * scale).collect(),
        )
        .unwrap()
    }

    #[test]
    fn test_multi_task_cca_creation() {
        let model = MultiTaskCCA::new(2, 0.1, 0.5);
        assert_eq!(model.n_components, 2);
        assert_eq!(model.reg_param, 0.1);
        assert_eq!(model.sharing_strength, 0.5);
    }

    #[test]
    fn test_shared_component_analysis_creation() {
        let model = SharedComponentAnalysis::new(3, 2, 0.05);
        assert_eq!(model.n_shared_components, 3);
        assert_eq!(model.n_specific_components, 2);
        assert_eq!(model.reg_param, 0.05);
    }

    #[test]
    fn test_multi_task_cca_fit() {
        let x1 = seq_matrix(20, 5, 0, 1.0);
        let y1 = seq_matrix(20, 3, 0, 1.5);
        let x2 = seq_matrix(20, 5, 50, 1.0);
        let y2 = seq_matrix(20, 3, 30, 1.2);

        let model = MultiTaskCCA::new(2, 0.1, 0.5);
        assert!(model.fit_multi_task(&[x1, x2], &[y1, y2]).is_ok());
    }

    #[test]
    fn test_shared_component_analysis_fit() {
        let data1 = seq_matrix(30, 6, 0, 1.0);
        let data2 = seq_matrix(30, 6, 20, 1.0);
        let data3 = seq_matrix(30, 6, 10, 1.0);

        let model = SharedComponentAnalysis::new(2, 1, 0.01);
        assert!(model.fit_datasets(&[data1, data2, data3]).is_ok());
    }

    #[test]
    fn test_transfer_learning_cca_creation() {
        let model = TransferLearningCCA::new(2, 0.1, 0.3);
        assert_eq!(model.n_components, 2);
        assert_eq!(model.reg_param, 0.1);
        assert_eq!(model.transfer_strength, 0.3);
    }

    #[test]
    fn test_transfer_learning_cca_fit() {
        let source_x = seq_matrix(20, 4, 0, 1.0);
        let source_y = seq_matrix(20, 3, 0, 1.1);
        let target_x = seq_matrix(15, 4, 10, 1.0);
        let target_y = seq_matrix(15, 3, 5, 1.2);

        let model = TransferLearningCCA::new(2, 0.1, 0.3);
        assert!(model
            .fit_transfer(&source_x, &source_y, &target_x, &target_y)
            .is_ok());
    }

    #[test]
    fn test_domain_adaptation_cca_creation() {
        let model = DomainAdaptationCCA::new(2, 0.05, 0.4);
        assert_eq!(model.n_components, 2);
        assert_eq!(model.reg_param, 0.05);
        assert_eq!(model.adaptation_strength, 0.4);
    }

    #[test]
    fn test_domain_adaptation_cca_fit() {
        let source_x = seq_matrix(25, 5, 0, 1.0);
        let source_y = seq_matrix(25, 3, 0, 0.9);
        let target_x = seq_matrix(20, 5, 15, 1.0);
        let target_y = seq_matrix(20, 3, 10, 1.1);

        let model = DomainAdaptationCCA::new(2, 0.05, 0.4);
        assert!(model
            .fit_domains(&source_x, &source_y, &target_x, &target_y)
            .is_ok());
    }

    #[test]
    fn test_few_shot_cca_creation() {
        let model = FewShotCCA::new(2, 5, 0.1, 0.01);
        assert_eq!(model.n_components, 2);
        assert_eq!(model.n_support_examples, 5);
        assert_eq!(model.reg_param, 0.1);
        assert_eq!(model.meta_learning_rate, 0.01);
    }

    #[test]
    fn test_few_shot_cca_meta_train() {
        let task1_x = seq_matrix(15, 4, 0, 1.0);
        let task1_y = seq_matrix(15, 3, 0, 1.1);
        let task2_x = seq_matrix(15, 4, 10, 1.0);
        let task2_y = seq_matrix(15, 3, 5, 0.9);

        let model = FewShotCCA::new(1, 3, 0.1, 0.01);
        assert!(model
            .meta_train(&[(task1_x, task1_y), (task2_x, task2_y)])
            .is_ok());
    }

    #[test]
    fn test_multi_task_cca_getters() {
        // Before fitting, every accessor reports None.
        let model = MultiTaskCCA::new(2, 0.1, 0.5);
        assert!(model.shared_weights_x().is_none());
        assert!(model.shared_weights_y().is_none());
        assert!(model.correlations().is_none());
    }

    #[test]
    fn test_shared_component_analysis_getters() {
        // Before fitting, every accessor reports None.
        let model = SharedComponentAnalysis::new(2, 1, 0.01);
        assert!(model.shared_components().is_none());
        assert!(model.specific_components(0).is_none());
        assert!(model.explained_variance_shared().is_none());
    }
}