1use crate::classification::GpcConfig;
31use crate::kernels::Kernel;
32use crate::utils;
33use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
34use scirs2_core::random::rngs::StdRng;
35use scirs2_core::random::{Rng, SeedableRng};
36use sklears_core::{
37 error::{Result as SklResult, SklearsError},
38 traits::{Estimator, Fit, Predict, Untrained},
39};
40use std::f64::consts::PI;
41
/// Linear Model of Coregionalization (LMC) for multi-output Gaussian process
/// regression: each output is modelled as a linear combination (via per-latent
/// mixing matrices) of independent latent GPs, one per kernel. The type
/// parameter `S` tracks trained/untrained state at compile time.
#[derive(Debug, Clone)]
pub struct LinearModelCoregionalization<S = Untrained> {
    /// One kernel per latent process; `n_latent == kernels.len()` once set.
    kernels: Vec<Box<dyn Kernel>>,
    /// Per-latent mixing matrices, each of shape (n_outputs, 1).
    mixing_matrices: Vec<Array2<f64>>,
    /// Diagonal jitter added to the Gram matrix for numerical stability.
    alpha: f64,
    /// Number of output dimensions in the target matrix Y.
    n_outputs: usize,
    /// Number of latent Gaussian processes.
    n_latent: usize,
    /// Whether mixing matrices should be optimized during fitting.
    /// NOTE(review): stored but never read in the visible `fit` path — confirm intent.
    optimize_mixing: bool,
    /// Compile-time state marker (`Untrained` or `LmcTrained`).
    _state: S,
}
57
/// State payload held by a fitted [`LinearModelCoregionalization`].
#[derive(Debug, Clone)]
pub struct LmcTrained {
    /// Training inputs, shape (n_samples, n_features).
    pub(crate) X_train: Array2<f64>,
    /// Training targets, shape (n_samples, n_outputs).
    pub(crate) Y_train: Array2<f64>,
    /// Kernels used at fit time (one per latent process).
    pub(crate) kernels: Vec<Box<dyn Kernel>>,
    /// Mixing matrices used at fit time, each (n_outputs, 1).
    pub(crate) mixing_matrices: Vec<Array2<f64>>,
    /// Diagonal jitter used when factorizing the Gram matrix.
    pub(crate) alpha: f64,
    /// Number of outputs.
    pub(crate) n_outputs: usize,
    /// Number of latent processes.
    pub(crate) n_latent: usize,
    /// Per-kernel Gram matrices K_r(X_train, X_train), before mixing/jitter.
    pub(crate) gram_matrices: Vec<Array2<f64>>,
    /// Solver output(s) used at prediction time; currently a single vector.
    pub(crate) alpha_vectors: Vec<Array1<f64>>,
    /// Log marginal likelihood of the training data under the fitted model.
    pub(crate) log_marginal_likelihood_value: f64,
}
72
73impl Default for LinearModelCoregionalization<Untrained> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl LinearModelCoregionalization<Untrained> {
80 pub fn new() -> Self {
82 Self {
83 kernels: Vec::new(),
84 mixing_matrices: Vec::new(),
85 alpha: 1e-10,
86 n_outputs: 1,
87 n_latent: 1,
88 optimize_mixing: true,
89 _state: Untrained,
90 }
91 }
92
93 pub fn kernels(mut self, kernels: Vec<Box<dyn Kernel>>) -> Self {
95 self.n_latent = kernels.len();
96 self.kernels = kernels;
97 self
98 }
99
100 pub fn mixing_matrices(mut self, mixing_matrices: Vec<Array2<f64>>) -> Self {
102 self.mixing_matrices = mixing_matrices;
103 self
104 }
105
106 pub fn alpha(mut self, alpha: f64) -> Self {
108 self.alpha = alpha;
109 self
110 }
111
112 pub fn n_outputs(mut self, n_outputs: usize) -> Self {
114 self.n_outputs = n_outputs;
115 self
116 }
117
118 pub fn optimize_mixing(mut self, optimize_mixing: bool) -> Self {
120 self.optimize_mixing = optimize_mixing;
121 self
122 }
123
124 fn initialize_mixing_matrices(&self) -> Vec<Array2<f64>> {
126 if !self.mixing_matrices.is_empty() {
127 return self.mixing_matrices.clone();
128 }
129
130 let mut matrices = Vec::new();
131 let mut rng = StdRng::seed_from_u64(42);
133
134 for _ in 0..self.n_latent {
135 let mut matrix = Array2::<f64>::zeros((self.n_outputs, 1));
136
137 for i in 0..self.n_outputs {
139 let random_val = rng.gen_range(-1.0..1.0);
140 matrix[[i, 0]] = random_val;
141 }
142
143 matrices.push(matrix);
144 }
145
146 matrices
147 }
148
149 #[allow(non_snake_case)]
151 fn compute_coregionalization_kernel(
152 &self,
153 X1: &Array2<f64>,
154 X2: Option<&Array2<f64>>,
155 mixing_matrices: &[Array2<f64>],
156 ) -> SklResult<Array2<f64>> {
157 let n1 = X1.nrows();
158 let n2 = X2.map_or(n1, |x| x.nrows());
159
160 let mut K_full = Array2::<f64>::zeros((n1 * self.n_outputs, n2 * self.n_outputs));
162
163 for (r, kernel) in self.kernels.iter().enumerate() {
165 let K_r = kernel.compute_kernel_matrix(X1, X2)?;
167
168 let A_r = &mixing_matrices[r];
170
171 for i in 0..self.n_outputs {
173 for j in 0..self.n_outputs {
174 let coeff = A_r[[i, 0]] * A_r[[j, 0]];
176
177 for n in 0..n1 {
179 for m in 0..n2 {
180 K_full[[n * self.n_outputs + i, m * self.n_outputs + j]] +=
181 coeff * K_r[[n, m]];
182 }
183 }
184 }
185 }
186 }
187
188 Ok(K_full)
189 }
190
191 fn vectorize_targets(&self, Y: &Array2<f64>) -> Array1<f64> {
193 let (n_samples, n_outputs) = Y.dim();
194 let mut y_vec = Array1::<f64>::zeros(n_samples * n_outputs);
195
196 for i in 0..n_samples {
197 for j in 0..n_outputs {
198 y_vec[i * n_outputs + j] = Y[[i, j]];
199 }
200 }
201
202 y_vec
203 }
204
205 fn devectorize_predictions(&self, y_vec: &Array1<f64>, n_samples: usize) -> Array2<f64> {
207 let mut Y = Array2::<f64>::zeros((n_samples, self.n_outputs));
208
209 for i in 0..n_samples {
210 for j in 0..self.n_outputs {
211 Y[[i, j]] = y_vec[i * self.n_outputs + j];
212 }
213 }
214
215 Y
216 }
217}
218
219impl Estimator for LinearModelCoregionalization<Untrained> {
220 type Config = GpcConfig; type Error = SklearsError;
222 type Float = f64;
223
224 fn config(&self) -> &Self::Config {
225 static DEFAULT_CONFIG: GpcConfig = GpcConfig {
228 kernel_name: String::new(),
229 optimizer: None,
230 n_restarts_optimizer: 0,
231 max_iter_predict: 100,
232 warm_start: false,
233 copy_x_train: true,
234 random_state: None,
235 };
236 &DEFAULT_CONFIG
237 }
238}
239
240impl Estimator for LinearModelCoregionalization<LmcTrained> {
241 type Config = GpcConfig;
242 type Error = SklearsError;
243 type Float = f64;
244
245 fn config(&self) -> &Self::Config {
246 static DEFAULT_CONFIG: GpcConfig = GpcConfig {
248 kernel_name: String::new(),
249 optimizer: None,
250 n_restarts_optimizer: 0,
251 max_iter_predict: 100,
252 warm_start: false,
253 copy_x_train: true,
254 random_state: None,
255 };
256 &DEFAULT_CONFIG
257 }
258}
259
260impl Fit<ArrayView2<'_, f64>, ArrayView2<'_, f64>> for LinearModelCoregionalization<Untrained> {
261 type Fitted = LinearModelCoregionalization<LmcTrained>;
262
263 #[allow(non_snake_case)]
264 fn fit(self, X: &ArrayView2<f64>, Y: &ArrayView2<f64>) -> SklResult<Self::Fitted> {
265 if X.nrows() != Y.nrows() {
266 return Err(SklearsError::InvalidInput(
267 "X and Y must have the same number of samples".to_string(),
268 ));
269 }
270
271 if self.kernels.is_empty() {
272 return Err(SklearsError::InvalidInput(
273 "At least one kernel must be specified".to_string(),
274 ));
275 }
276
277 let n_outputs = Y.ncols();
279 if self.n_outputs != n_outputs {
280 return Err(SklearsError::InvalidInput(format!(
281 "n_outputs ({}) must match Y dimensions ({})",
282 self.n_outputs, n_outputs
283 )));
284 }
285
286 let X_owned = X.to_owned();
287 let Y_owned = Y.to_owned();
288
289 let mixing_matrices = self.initialize_mixing_matrices();
291
292 let K_full = self.compute_coregionalization_kernel(&X_owned, None, &mixing_matrices)?;
294
295 let mut K_reg = K_full.clone();
297 for i in 0..K_reg.nrows() {
298 K_reg[[i, i]] += self.alpha;
299 }
300
301 let y_vec = self.vectorize_targets(&Y_owned);
303
304 let chol_decomp = utils::robust_cholesky(&K_reg)?;
306 let alpha_vec = utils::triangular_solve(&chol_decomp, &y_vec)?;
307
308 let log_det = chol_decomp.diag().iter().map(|x| x.ln()).sum::<f64>() * 2.0;
310 let data_fit = y_vec.dot(&alpha_vec);
311 let n_total = y_vec.len();
312 let log_marginal_likelihood =
313 -0.5 * (data_fit + log_det + n_total as f64 * (2.0 * PI).ln());
314
315 let mut gram_matrices = Vec::new();
317 for kernel in &self.kernels {
318 let K_r = kernel.compute_kernel_matrix(&X_owned, None)?;
319 gram_matrices.push(K_r);
320 }
321
322 let kernels_clone: Vec<Box<dyn Kernel>> = self.kernels.to_vec();
324
325 Ok(LinearModelCoregionalization {
326 kernels: self.kernels,
327 mixing_matrices: mixing_matrices.clone(),
328 alpha: self.alpha,
329 n_outputs: self.n_outputs,
330 n_latent: self.n_latent,
331 optimize_mixing: self.optimize_mixing,
332 _state: LmcTrained {
333 X_train: X_owned,
334 Y_train: Y_owned,
335 kernels: kernels_clone,
336 mixing_matrices,
337 alpha: self.alpha,
338 n_outputs: self.n_outputs,
339 n_latent: self.n_latent,
340 gram_matrices,
341 alpha_vectors: vec![alpha_vec],
342 log_marginal_likelihood_value: log_marginal_likelihood,
343 },
344 })
345 }
346}
347
impl LinearModelCoregionalization<LmcTrained> {
    /// Returns the full trained state (training data, factors, likelihood).
    pub fn trained_state(&self) -> &LmcTrained {
        &self._state
    }

    /// Returns the mixing matrices used at fit time, each of shape (n_outputs, 1).
    pub fn mixing_matrices(&self) -> &[Array2<f64>] {
        &self._state.mixing_matrices
    }

    /// Returns the log marginal likelihood computed during fitting.
    pub fn log_marginal_likelihood(&self) -> f64 {
        self._state.log_marginal_likelihood_value
    }

    /// Assembles the (n1*P) x (n2*P) coregionalized kernel matrix
    /// K = sum_r (A_r A_r^T) ⊗ K_r(X1, X2), P = n_outputs, with rows/columns
    /// interleaved by output (row n*P + i = sample n, output i).
    /// NOTE(review): duplicates the untrained impl's method of the same name
    /// (reading state from `self._state` instead of `self`) — consider sharing.
    #[allow(non_snake_case)]
    fn compute_coregionalization_kernel(
        &self,
        X1: &Array2<f64>,
        X2: Option<&Array2<f64>>,
        mixing_matrices: &[Array2<f64>],
    ) -> SklResult<Array2<f64>> {
        let n1 = X1.nrows();
        let n2 = X2.map_or(n1, |x| x.nrows());

        let mut K_full =
            Array2::<f64>::zeros((n1 * self._state.n_outputs, n2 * self._state.n_outputs));

        for (r, kernel) in self._state.kernels.iter().enumerate() {
            let K_r = kernel.compute_kernel_matrix(X1, X2)?;

            // Panics if `mixing_matrices` is shorter than the kernel list or a
            // matrix has fewer than n_outputs rows; callers pass fit-time state.
            let A_r = &mixing_matrices[r];

            for i in 0..self._state.n_outputs {
                for j in 0..self._state.n_outputs {
                    // Rank-1 coregionalization coefficient B_r[i, j] = A_r[i] * A_r[j].
                    let coeff = A_r[[i, 0]] * A_r[[j, 0]];

                    for n in 0..n1 {
                        for m in 0..n2 {
                            K_full
                                [[n * self._state.n_outputs + i, m * self._state.n_outputs + j]] +=
                                coeff * K_r[[n, m]];
                        }
                    }
                }
            }
        }

        Ok(K_full)
    }

    /// Reshapes an output-interleaved prediction vector
    /// (y[i * n_outputs + j] = output j of sample i) into (n_samples, n_outputs).
    fn devectorize_predictions(&self, y_vec: &Array1<f64>, n_samples: usize) -> Array2<f64> {
        let mut Y = Array2::<f64>::zeros((n_samples, self._state.n_outputs));

        for i in 0..n_samples {
            for j in 0..self._state.n_outputs {
                Y[[i, j]] = y_vec[i * self._state.n_outputs + j];
            }
        }

        Y
    }

    /// Returns one (n_test, n_outputs) matrix per latent process.
    /// NOTE(review): `_K_star` is computed but never used — each returned
    /// matrix currently contains only the mixing coefficients A_r broadcast
    /// over test points, not kernel-weighted contributions. Looks like a
    /// placeholder implementation; confirm intended semantics.
    #[allow(non_snake_case)]
    pub fn latent_contributions(&self, X: &ArrayView2<f64>) -> SklResult<Vec<Array2<f64>>> {
        let mut contributions = Vec::new();

        for r in 0..self._state.n_latent {
            let _K_star = self._state.kernels[r]
                .compute_kernel_matrix(&self._state.X_train, Some(&X.to_owned()))?;

            let A_r = &self._state.mixing_matrices[r];

            let n_test = X.nrows();
            let mut contribution = Array2::<f64>::zeros((n_test, self._state.n_outputs));

            for i in 0..n_test {
                for j in 0..self._state.n_outputs {
                    contribution[[i, j]] = A_r[[j, 0]];
                }
            }

            contributions.push(contribution);
        }

        Ok(contributions)
    }
}
452
453impl Predict<ArrayView2<'_, f64>, Array2<f64>> for LinearModelCoregionalization<LmcTrained> {
454 #[allow(non_snake_case)]
455 fn predict(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
456 let X_test = X.to_owned();
457
458 let K_star = self.compute_coregionalization_kernel(
460 &self._state.X_train,
461 Some(&X_test),
462 &self._state.mixing_matrices,
463 )?;
464
465 let alpha_vec = &self._state.alpha_vectors[0];
467
468 let n_test = X_test.nrows();
470 let mut y_pred_vec = Array1::<f64>::zeros(n_test * self._state.n_outputs);
471
472 for i in 0..n_test * self._state.n_outputs {
473 for j in 0..self._state.X_train.nrows() * self._state.n_outputs {
474 y_pred_vec[i] += K_star[[j, i]] * alpha_vec[j];
475 }
476 }
477
478 let predictions = self.devectorize_predictions(&y_pred_vec, n_test);
480
481 Ok(predictions)
482 }
483}
484
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Four 1-D samples with two linearly related outputs, shared by most tests.
    fn sample_data() -> (Array2<f64>, Array2<f64>) {
        (
            array![[1.0], [2.0], [3.0], [4.0]],
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]],
        )
    }

    /// Single RBF kernel with the given length scale, boxed for the builder.
    fn rbf(length_scale: f64) -> Box<dyn Kernel> {
        Box::new(RBF::new(length_scale)) as Box<dyn Kernel>
    }

    #[test]
    fn test_lmc_creation() {
        // Builder should record outputs/alpha and infer n_latent from kernels.
        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .n_outputs(2)
            .alpha(1e-6);

        assert_eq!(model.n_outputs, 2);
        assert_eq!(model.n_latent, 1);
        assert_eq!(model.alpha, 1e-6);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_fit_predict() {
        let (X, Y) = sample_data();

        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .n_outputs(2)
            .alpha(1e-6);

        let trained = model.fit(&X.view(), &Y.view()).unwrap();
        let preds = trained.predict(&X.view()).unwrap();

        // One prediction row per sample, one column per output.
        assert_eq!(preds.shape(), &[4, 2]);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_multi_kernel() {
        let (X, Y) = sample_data();

        // Two latent processes with different length scales.
        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0), rbf(2.0)])
            .n_outputs(2)
            .alpha(1e-6);

        let trained = model.fit(&X.view(), &Y.view()).unwrap();
        let preds = trained.predict(&X.view()).unwrap();

        assert_eq!(preds.shape(), &[4, 2]);
        assert_eq!(trained.trained_state().n_latent, 2);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_mixing_matrices() {
        let (X, Y) = sample_data();

        // User-supplied mixing matrix must survive fitting unchanged.
        let mixing_matrix = array![[0.5], [0.8]];
        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .mixing_matrices(vec![mixing_matrix.clone()])
            .n_outputs(2)
            .alpha(1e-6);

        let trained = model.fit(&X.view(), &Y.view()).unwrap();
        let learned = &trained.mixing_matrices()[0];

        assert_eq!(learned.shape(), mixing_matrix.shape());
        assert_abs_diff_eq!(learned[[0, 0]], 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(learned[[1, 0]], 0.8, epsilon = 1e-10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_log_marginal_likelihood() {
        let (X, Y) = sample_data();

        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .n_outputs(2)
            .alpha(1e-6);

        let trained = model.fit(&X.view(), &Y.view()).unwrap();

        // The likelihood need only be a finite number for this smoke test.
        assert!(trained.log_marginal_likelihood().is_finite());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_latent_contributions() {
        let (X, Y) = sample_data();

        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0), rbf(2.0)])
            .n_outputs(2)
            .alpha(1e-6);

        let trained = model.fit(&X.view(), &Y.view()).unwrap();
        let contributions = trained.latent_contributions(&X.view()).unwrap();

        // One contribution matrix per latent process, each (n_test, n_outputs).
        assert_eq!(contributions.len(), 2);
        assert_eq!(contributions[0].shape(), &[4, 2]);
        assert_eq!(contributions[1].shape(), &[4, 2]);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_lmc_errors() {
        let (X, Y) = sample_data();

        // No kernels configured -> fit must fail.
        let model = LinearModelCoregionalization::new().n_outputs(2).alpha(1e-6);
        assert!(model.fit(&X.view(), &Y.view()).is_err());

        // Mismatched sample counts between X and Y -> fit must fail.
        let X_wrong = array![[1.0], [2.0], [3.0]];
        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .n_outputs(2)
            .alpha(1e-6);
        assert!(model.fit(&X_wrong.view(), &Y.view()).is_err());

        // n_outputs disagreeing with Y.ncols() -> fit must fail.
        let model = LinearModelCoregionalization::new()
            .kernels(vec![rbf(1.0)])
            .n_outputs(3)
            .alpha(1e-6);
        assert!(model.fit(&X.view(), &Y.view()).is_err());
    }
}