1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
19use sklears_core::{
22 error::{Result as SklResult, SklearsError},
23 traits::{Estimator, Untrained},
24};
25use std::collections::HashMap;
26use std::f64::consts::PI;
27
28use crate::kernels::Kernel;
29use crate::utils;
30
/// Configuration for the multi-task Gaussian process regressor.
#[derive(Debug, Clone)]
pub struct MtgpConfig {
    /// Name of the kernel shared across all tasks (e.g. "RBF").
    pub shared_kernel_name: String,
    /// Name of the kernel that couples samples within the same task only.
    pub task_kernel_name: String,
    /// Diagonal regularization (jitter/noise) added to the joint covariance.
    pub alpha: f64,
    /// Weight of the shared-kernel contribution (applied squared in the covariance).
    pub shared_weight: f64,
    /// Weight of the task-kernel contribution (applied squared in the covariance).
    pub task_weight: f64,
    /// Optional RNG seed; not used by the fitting code visible in this file.
    pub random_state: Option<u64>,
}
47
48impl Default for MtgpConfig {
49 fn default() -> Self {
50 Self {
51 shared_kernel_name: "RBF".to_string(),
52 task_kernel_name: "RBF".to_string(),
53 alpha: 1e-10,
54 shared_weight: 1.0,
55 task_weight: 1.0,
56 random_state: None,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
100pub struct MultiTaskGaussianProcessRegressor<S = Untrained> {
101 shared_kernel: Option<Box<dyn Kernel>>,
102 task_kernel: Option<Box<dyn Kernel>>,
103 tasks: HashMap<String, (Array2<f64>, Array1<f64>)>, alpha: f64,
105 shared_weight: f64,
106 task_weight: f64,
107 _state: S,
108}
109
110#[derive(Debug, Clone)]
112pub struct MtgpTrained {
113 tasks: HashMap<String, (Array2<f64>, Array1<f64>)>,
114 shared_kernel: Box<dyn Kernel>,
115 task_kernel: Box<dyn Kernel>,
116 alpha: f64,
117 shared_weight: f64,
118 task_weight: f64,
119 alpha_vector: Array1<f64>, log_marginal_likelihood_values: HashMap<String, f64>, task_indices: HashMap<String, (usize, usize)>, all_X: Array2<f64>, all_y: Array1<f64>, }
125
126impl MultiTaskGaussianProcessRegressor<Untrained> {
127 pub fn new() -> Self {
129 Self {
130 shared_kernel: None,
131 task_kernel: None,
132 tasks: HashMap::new(),
133 alpha: 1e-10,
134 shared_weight: 1.0,
135 task_weight: 1.0,
136 _state: Untrained,
137 }
138 }
139
140 pub fn shared_kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
142 self.shared_kernel = Some(kernel);
143 self
144 }
145
146 pub fn task_kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
148 self.task_kernel = Some(kernel);
149 self
150 }
151
152 pub fn alpha(mut self, alpha: f64) -> Self {
154 self.alpha = alpha;
155 self
156 }
157
158 pub fn shared_weight(mut self, weight: f64) -> Self {
160 self.shared_weight = weight;
161 self
162 }
163
164 pub fn task_weight(mut self, weight: f64) -> Self {
166 self.task_weight = weight;
167 self
168 }
169
170 pub fn add_task(
172 mut self,
173 task_name: &str,
174 X: &ArrayView2<f64>,
175 y: &ArrayView1<f64>,
176 ) -> SklResult<Self> {
177 if X.nrows() != y.len() {
178 return Err(SklearsError::InvalidInput(
179 "X and y must have the same number of samples".to_string(),
180 ));
181 }
182
183 self.tasks
184 .insert(task_name.to_string(), (X.to_owned(), y.to_owned()));
185 Ok(self)
186 }
187
188 pub fn remove_task(mut self, task_name: &str) -> Self {
190 self.tasks.remove(task_name);
191 self
192 }
193
194 pub fn task_names(&self) -> Vec<String> {
196 self.tasks.keys().cloned().collect()
197 }
198
199 fn combine_task_data(
201 &self,
202 ) -> SklResult<(Array2<f64>, Array1<f64>, HashMap<String, (usize, usize)>)> {
203 if self.tasks.is_empty() {
204 return Err(SklearsError::InvalidInput(
205 "At least one task must be added".to_string(),
206 ));
207 }
208
209 let mut all_X_vec = Vec::new();
210 let mut all_y_vec = Vec::new();
211 let mut task_indices = HashMap::new();
212 let mut current_idx = 0;
213
214 let first_task = self.tasks.values().next().unwrap();
216 let n_features = first_task.0.ncols();
217
218 for (task_name, (X, y)) in &self.tasks {
219 if X.ncols() != n_features {
220 return Err(SklearsError::InvalidInput(
221 "All tasks must have the same number of features".to_string(),
222 ));
223 }
224
225 let n_samples = X.nrows();
226 task_indices.insert(task_name.clone(), (current_idx, current_idx + n_samples));
227
228 for i in 0..n_samples {
230 let mut row = Vec::new();
231 for j in 0..n_features {
232 row.push(X[[i, j]]);
233 }
234 all_X_vec.push(row);
235 }
236
237 for i in 0..n_samples {
239 all_y_vec.push(y[i]);
240 }
241
242 current_idx += n_samples;
243 }
244
245 let n_total = all_X_vec.len();
246 let mut all_X = Array2::<f64>::zeros((n_total, n_features));
247 let mut all_y = Array1::<f64>::zeros(n_total);
248
249 for (i, row) in all_X_vec.iter().enumerate() {
250 for (j, &val) in row.iter().enumerate() {
251 all_X[[i, j]] = val;
252 }
253 }
254
255 for (i, &val) in all_y_vec.iter().enumerate() {
256 all_y[i] = val;
257 }
258
259 Ok((all_X, all_y, task_indices))
260 }
261
262 #[allow(non_snake_case)]
264 fn compute_multitask_covariance(
265 &self,
266 X: &Array2<f64>,
267 task_indices: &HashMap<String, (usize, usize)>,
268 shared_kernel: &Box<dyn Kernel>,
269 task_kernel: &Box<dyn Kernel>,
270 ) -> SklResult<Array2<f64>> {
271 let n = X.nrows();
272 let mut K = Array2::<f64>::zeros((n, n));
273
274 let K_shared = shared_kernel.compute_kernel_matrix(X, None)?;
276
277 for i in 0..n {
279 for j in 0..n {
280 K[[i, j]] += self.shared_weight * self.shared_weight * K_shared[[i, j]];
281 }
282 }
283
284 let K_task = task_kernel.compute_kernel_matrix(X, None)?;
286
287 for (start_i, end_i) in task_indices.values() {
288 for i in *start_i..*end_i {
289 for j in *start_i..*end_i {
290 K[[i, j]] += self.task_weight * self.task_weight * K_task[[i, j]];
291 }
292 }
293 }
294
295 Ok(K)
296 }
297}
298
299impl Default for MultiTaskGaussianProcessRegressor<Untrained> {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
impl Estimator for MultiTaskGaussianProcessRegressor<Untrained> {
    type Config = MtgpConfig;
    type Error = SklearsError;
    type Float = f64;

    // NOTE(review): returns a static placeholder, not the values set on this
    // instance (alpha/weights/kernels configured via the builder are ignored
    // here) — confirm whether the `Estimator` contract expects live state.
    fn config(&self) -> &Self::Config {
        // `String::new()` is a const fn, so this static is const-initialized.
        static DEFAULT_CONFIG: MtgpConfig = MtgpConfig {
            shared_kernel_name: String::new(),
            task_kernel_name: String::new(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            random_state: None,
        };
        &DEFAULT_CONFIG
    }
}
322
impl Estimator for MultiTaskGaussianProcessRegressor<MtgpTrained> {
    type Config = MtgpConfig;
    type Error = SklearsError;
    type Float = f64;

    // NOTE(review): returns a static placeholder, not the configuration the
    // model was actually fitted with — confirm whether the `Estimator`
    // contract expects live state.
    fn config(&self) -> &Self::Config {
        // `String::new()` is a const fn, so this static is const-initialized.
        static DEFAULT_CONFIG: MtgpConfig = MtgpConfig {
            shared_kernel_name: String::new(),
            task_kernel_name: String::new(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            random_state: None,
        };
        &DEFAULT_CONFIG
    }
}
340
impl MultiTaskGaussianProcessRegressor<Untrained> {
    /// Fits the multi-task GP jointly over all registered tasks.
    ///
    /// Stacks the task data, builds the joint covariance (shared kernel
    /// everywhere, task kernel inside each task block), adds `alpha` to the
    /// diagonal, factorizes, and solves for the weight vector used at
    /// prediction time.
    ///
    /// # Errors
    /// Fails when either kernel is missing, no task was added, or the
    /// factorization/solve in `utils` fails.
    #[allow(non_snake_case)]
    pub fn fit(self) -> SklResult<MultiTaskGaussianProcessRegressor<MtgpTrained>> {
        // Both kernels must be configured before fitting.
        let shared_kernel = self.shared_kernel.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Shared kernel must be specified".to_string())
        })?;

        let task_kernel = self.task_kernel.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Task kernel must be specified".to_string())
        })?;

        if self.tasks.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one task must be added".to_string(),
            ));
        }

        // `task_indices` maps each task to its half-open row range in the stack.
        let (all_X, all_y, task_indices) = self.combine_task_data()?;

        let K =
            self.compute_multitask_covariance(&all_X, &task_indices, shared_kernel, task_kernel)?;

        // Regularize the diagonal (jitter / noise level) before factorizing.
        let mut K_reg = K.clone();
        for i in 0..K_reg.nrows() {
            K_reg[[i, i]] += self.alpha;
        }

        // NOTE(review): only one triangular solve appears here; the full GP
        // weight solve is L L^T alpha = y (forward then backward substitution).
        // Whether `utils::triangular_solve` performs the complete solve cannot
        // be confirmed from this file — verify against its implementation.
        let chol_decomp = utils::robust_cholesky(&K_reg)?;
        let alpha_vector = utils::triangular_solve(&chol_decomp, &all_y)?;

        // Per-task log-marginal-likelihood values.
        // NOTE(review): only the data-fit term and the 2*pi constant are used;
        // the log-determinant term of the standard expression is omitted, so
        // these are not the full log marginal likelihoods.
        let mut log_marginal_likelihood_values = HashMap::new();
        for (task_name, (start_idx, end_idx)) in &task_indices {
            let task_size = end_idx - start_idx;
            let task_y = all_y.slice(scirs2_core::ndarray::s![*start_idx..*end_idx]);
            let task_alpha = alpha_vector.slice(scirs2_core::ndarray::s![*start_idx..*end_idx]);

            let data_fit = task_y.dot(&task_alpha);
            let log_ml = -0.5 * (data_fit + task_size as f64 * (2.0 * PI).ln());
            log_marginal_likelihood_values.insert(task_name.clone(), log_ml);
        }

        // The trained struct stores the task data both at the top level and
        // inside `_state`, so one clone is required here.
        Ok(MultiTaskGaussianProcessRegressor {
            shared_kernel: None,
            task_kernel: None,
            tasks: self.tasks.clone(),
            alpha: self.alpha,
            shared_weight: self.shared_weight,
            task_weight: self.task_weight,
            _state: MtgpTrained {
                tasks: self.tasks,
                shared_kernel: shared_kernel.clone(),
                task_kernel: task_kernel.clone(),
                alpha: self.alpha,
                shared_weight: self.shared_weight,
                task_weight: self.task_weight,
                alpha_vector,
                log_marginal_likelihood_values,
                task_indices,
                all_X,
                all_y,
            },
        })
    }
}
413
414impl MultiTaskGaussianProcessRegressor<MtgpTrained> {
415 pub fn trained_state(&self) -> &MtgpTrained {
417 &self._state
418 }
419
420 pub fn log_marginal_likelihood_task(&self, task_name: &str) -> Option<f64> {
422 self._state
423 .log_marginal_likelihood_values
424 .get(task_name)
425 .copied()
426 }
427
428 pub fn log_marginal_likelihoods(&self) -> &HashMap<String, f64> {
430 &self._state.log_marginal_likelihood_values
431 }
432
433 pub fn task_names(&self) -> Vec<String> {
435 self._state.tasks.keys().cloned().collect()
436 }
437
438 #[allow(non_snake_case)]
440 pub fn predict_task(&self, task_name: &str, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
441 let task_data =
442 self._state.tasks.get(task_name).ok_or_else(|| {
443 SklearsError::InvalidInput(format!("Task '{}' not found", task_name))
444 })?;
445
446 let task_X_train = &task_data.0;
447 let n_test = X.nrows();
448
449 let K_shared_star = self
451 ._state
452 .shared_kernel
453 .compute_kernel_matrix(&self._state.all_X, Some(&X.to_owned()))?;
454 let K_task_star = self
455 ._state
456 .task_kernel
457 .compute_kernel_matrix(task_X_train, Some(&X.to_owned()))?;
458
459 let mut predictions = Array1::<f64>::zeros(n_test);
461
462 for i in 0..n_test {
463 let mut pred = 0.0;
464
465 for j in 0..self._state.all_X.nrows() {
467 pred += self.shared_weight
468 * self.shared_weight
469 * K_shared_star[[j, i]]
470 * self._state.alpha_vector[j];
471 }
472
473 if let Some((start_idx, _end_idx)) = self._state.task_indices.get(task_name) {
475 for j in 0..task_X_train.nrows() {
476 let global_j = start_idx + j;
477 pred += self.task_weight
478 * self.task_weight
479 * K_task_star[[j, i]]
480 * self._state.alpha_vector[global_j];
481 }
482 }
483
484 predictions[i] = pred;
485 }
486
487 Ok(predictions)
488 }
489
490 #[allow(non_snake_case)]
492 pub fn predict_task_components(
493 &self,
494 task_name: &str,
495 X: &ArrayView2<f64>,
496 ) -> SklResult<(Array1<f64>, Array1<f64>)> {
497 let task_data =
498 self._state.tasks.get(task_name).ok_or_else(|| {
499 SklearsError::InvalidInput(format!("Task '{}' not found", task_name))
500 })?;
501
502 let task_X_train = &task_data.0;
503 let n_test = X.nrows();
504
505 let K_shared_star = self
507 ._state
508 .shared_kernel
509 .compute_kernel_matrix(&self._state.all_X, Some(&X.to_owned()))?;
510 let K_task_star = self
511 ._state
512 .task_kernel
513 .compute_kernel_matrix(task_X_train, Some(&X.to_owned()))?;
514
515 let mut shared_predictions = Array1::<f64>::zeros(n_test);
516 let mut task_predictions = Array1::<f64>::zeros(n_test);
517
518 for i in 0..n_test {
519 for j in 0..self._state.all_X.nrows() {
521 shared_predictions[i] +=
522 self.shared_weight * K_shared_star[[j, i]] * self._state.alpha_vector[j];
523 }
524
525 if let Some((start_idx, _)) = self._state.task_indices.get(task_name) {
527 for j in 0..task_X_train.nrows() {
528 let global_j = start_idx + j;
529 task_predictions[i] +=
530 self.task_weight * K_task_star[[j, i]] * self._state.alpha_vector[global_j];
531 }
532 }
533 }
534
535 Ok((shared_predictions, task_predictions))
536 }
537}
538
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;

    use scirs2_core::ndarray::array;

    /// Builder setters are stored and the task map starts empty.
    #[test]
    fn test_mtgp_creation() {
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .alpha(1e-6);

        assert_eq!(mtgp.alpha, 1e-6);
        assert_eq!(mtgp.shared_weight, 1.0);
        assert_eq!(mtgp.task_weight, 1.0);
        assert_eq!(mtgp.tasks.len(), 0);
    }

    /// Adding a task stores it under its name and exposes it via task_names().
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_add_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        assert_eq!(mtgp.tasks.len(), 1);
        assert!(mtgp.tasks.contains_key("task1"));
        let task_names = mtgp.task_names();
        assert!(task_names.contains(&"task1".to_string()));
    }

    /// remove_task drops a previously added task.
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_remove_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap()
            .remove_task("task1");

        assert_eq!(mtgp.tasks.len(), 0);
    }

    /// Fitting with a single task succeeds and records its log marginal likelihood.
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_fit_single_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert_eq!(fitted.task_names().len(), 1);
        assert!(fitted.log_marginal_likelihood_task("task1").is_some());
    }

    /// Fitting with two tasks records a log marginal likelihood for each.
    #[test]
    fn test_mtgp_fit_multiple_tasks() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert_eq!(fitted.task_names().len(), 2);
        assert!(fitted.log_marginal_likelihood_task("task1").is_some());
        assert!(fitted.log_marginal_likelihood_task("task2").is_some());
    }

    /// Predictions for each task have one entry per test row.
    #[test]
    fn test_mtgp_predict_task() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();

        let predictions = fitted.predict_task("task1", &X1.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let predictions2 = fitted.predict_task("task2", &X2.view()).unwrap();
        assert_eq!(predictions2.len(), 4);
    }

    /// Component breakdown returns one shared and one task vector per test row.
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_predict_components() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        let (shared_pred, task_pred) = fitted.predict_task_components("task1", &X.view()).unwrap();

        assert_eq!(shared_pred.len(), 4);
        assert_eq!(task_pred.len(), 4);
    }

    /// The likelihood map covers every task and contains finite values.
    #[test]
    fn test_mtgp_log_marginal_likelihoods() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        let all_lml = fitted.log_marginal_likelihoods();

        assert_eq!(all_lml.len(), 2);
        assert!(all_lml.contains_key("task1"));
        assert!(all_lml.contains_key("task2"));
        assert!(all_lml.get("task1").unwrap().is_finite());
        assert!(all_lml.get("task2").unwrap().is_finite());
    }

    /// Error paths: missing shared kernel, missing task kernel, no tasks,
    /// and predicting for a task name that was never added.
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_errors() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        // No shared kernel -> fit must fail.
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();
        assert!(mtgp.fit().is_err());

        // No task kernel -> fit must fail.
        let shared_kernel = RBF::new(1.0);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();
        assert!(mtgp.fit().is_err());

        // Kernels set but no tasks -> fit must fail.
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel));
        assert!(mtgp.fit().is_err());

        // Predicting an unknown task name on a fitted model must fail.
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert!(fitted.predict_task("nonexistent", &X.view()).is_err());
    }

    /// add_task rejects X and y with differing sample counts.
    #[test]
    fn test_mtgp_mismatched_dimensions() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y_wrong = array![1.0, 4.0, 9.0]; // 3 targets vs 4 samples
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel));

        assert!(mtgp.add_task("task1", &X1.view(), &y_wrong.view()).is_err());
    }
}