sklears_gaussian_process/multi_task.rs

//! Multi-Task Gaussian Process for learning multiple related tasks
//!
//! This module implements multi-task Gaussian processes that learn multiple related tasks
//! simultaneously by sharing information across tasks. This is particularly useful when
//! you have several related learning problems that can benefit from shared knowledge.
//!
//! # Mathematical Background
//!
//! The multi-task GP models each task t using both shared and task-specific components:
//! f_t(x) = f_shared(x) + f_task_t(x)
//!
//! where:
//! - f_shared(x) is a shared latent function common to all tasks
//! - f_task_t(x) is a task-specific function unique to task t
//! - Each component has its own kernel and hyperparameters
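//!
//! Stacking the observations of all tasks into one vector gives a joint
//! covariance with block structure: the shared kernel contributes to every
//! pair of points, while the task kernel contributes only to pairs drawn
//! from the same task:
//!
//! cov[f_i(x), f_j(x')] = w_shared² * k_shared(x, x') + δ_{i,j} * w_task² * k_task(x, x')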

// SciRS2 Policy - Use scirs2-autograd for ndarray types and array! macro
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
// SciRS2 Policy - Use scirs2-core for random operations

use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Untrained},
};
use std::collections::HashMap;
use std::f64::consts::PI;

use crate::kernels::Kernel;
use crate::utils;
/// Configuration for Multi-Task Gaussian Process
#[derive(Debug, Clone)]
pub struct MtgpConfig {
    /// Name of the shared kernel
    pub shared_kernel_name: String,
    /// Name of the task-specific kernel
    pub task_kernel_name: String,
    /// Regularization added to the diagonal of the covariance matrix
    pub alpha: f64,
    /// Weight of the shared component
    pub shared_weight: f64,
    /// Weight of the task-specific components
    pub task_weight: f64,
    /// Seed for reproducible random operations
    pub random_state: Option<u64>,
}

impl Default for MtgpConfig {
    fn default() -> Self {
        Self {
            shared_kernel_name: "RBF".to_string(),
            task_kernel_name: "RBF".to_string(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            random_state: None,
        }
    }
}

/// Multi-Task Gaussian Process Regressor
///
/// This implementation allows learning multiple related tasks simultaneously by sharing
/// information through a shared latent function while maintaining task-specific variations.
///
/// # Mathematical Background
///
/// For each task t, the model assumes:
/// f_t(x) = w_shared * f_shared(x) + w_task * f_task_t(x)
///
/// where the covariance between tasks i and j at points x and x' is:
/// cov[f_i(x), f_j(x')] = w_shared² * k_shared(x, x') + δ_{i,j} * w_task² * k_task(x, x')
///
/// # Examples
///
/// ```
/// use sklears_gaussian_process::{MultiTaskGaussianProcessRegressor, kernels::RBF};
/// // SciRS2 Policy - Use scirs2-autograd for ndarray types and operations
/// use scirs2_core::ndarray::array;
///
/// let X1 = array![[1.0], [2.0], [3.0], [4.0]];
/// let y1 = array![1.0, 4.0, 9.0, 16.0];
/// let X2 = array![[1.5], [2.5], [3.5], [4.5]];
/// let y2 = array![2.0, 6.0, 12.0, 20.0];
///
/// let shared_kernel = RBF::new(1.0);
/// let task_kernel = RBF::new(0.5);
/// let mtgp = MultiTaskGaussianProcessRegressor::new()
///     .shared_kernel(Box::new(shared_kernel))
///     .task_kernel(Box::new(task_kernel))
///     .alpha(1e-6);
///
/// let mut mtgp = mtgp.add_task("task1", &X1.view(), &y1.view()).unwrap();
/// mtgp = mtgp.add_task("task2", &X2.view(), &y2.view()).unwrap();
/// let fitted = mtgp.fit().unwrap();
/// let predictions = fitted.predict_task("task1", &X1.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MultiTaskGaussianProcessRegressor<S = Untrained> {
    shared_kernel: Option<Box<dyn Kernel>>,
    task_kernel: Option<Box<dyn Kernel>>,
    tasks: HashMap<String, (Array2<f64>, Array1<f64>)>, // task_name -> (X, y)
    alpha: f64,
    shared_weight: f64,
    task_weight: f64,
    _state: S,
}

/// Trained state for Multi-Task Gaussian Process
#[derive(Debug, Clone)]
pub struct MtgpTrained {
    tasks: HashMap<String, (Array2<f64>, Array1<f64>)>,
    shared_kernel: Box<dyn Kernel>,
    task_kernel: Box<dyn Kernel>,
    alpha: f64,
    shared_weight: f64,
    task_weight: f64,
    alpha_vector: Array1<f64>, // Solution to the linear system
    log_marginal_likelihood_values: HashMap<String, f64>, // Per-task log marginal likelihood
    task_indices: HashMap<String, (usize, usize)>, // task_name -> (start_idx, end_idx)
    all_X: Array2<f64>,        // Combined input data
    all_y: Array1<f64>,        // Combined target data
}

impl MultiTaskGaussianProcessRegressor<Untrained> {
    /// Create a new Multi-Task Gaussian Process Regressor
    pub fn new() -> Self {
        Self {
            shared_kernel: None,
            task_kernel: None,
            tasks: HashMap::new(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            _state: Untrained,
        }
    }

    /// Set the shared kernel function
    pub fn shared_kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.shared_kernel = Some(kernel);
        self
    }

    /// Set the task-specific kernel function
    pub fn task_kernel(mut self, kernel: Box<dyn Kernel>) -> Self {
        self.task_kernel = Some(kernel);
        self
    }

    /// Set the regularization parameter
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the weight for the shared component
    pub fn shared_weight(mut self, weight: f64) -> Self {
        self.shared_weight = weight;
        self
    }

    /// Set the weight for the task-specific components
    pub fn task_weight(mut self, weight: f64) -> Self {
        self.task_weight = weight;
        self
    }

    /// Add a task with its training data
    pub fn add_task(
        mut self,
        task_name: &str,
        X: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> SklResult<Self> {
        if X.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        self.tasks
            .insert(task_name.to_string(), (X.to_owned(), y.to_owned()));
        Ok(self)
    }

    /// Remove a task
    pub fn remove_task(mut self, task_name: &str) -> Self {
        self.tasks.remove(task_name);
        self
    }

    /// Get the list of task names
    pub fn task_names(&self) -> Vec<String> {
        self.tasks.keys().cloned().collect()
    }

    /// Combine all task data into single arrays
    fn combine_task_data(
        &self,
    ) -> SklResult<(Array2<f64>, Array1<f64>, HashMap<String, (usize, usize)>)> {
        if self.tasks.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one task must be added".to_string(),
            ));
        }

        let mut all_X_vec = Vec::new();
        let mut all_y_vec = Vec::new();
        let mut task_indices = HashMap::new();
        let mut current_idx = 0;

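        // Note: `self.tasks` is a HashMap, so iteration order (and hence the
        // order of task blocks in the combined arrays) is unspecified and may
        // differ between runs; `task_indices` records the actual offsets, so
        // all downstream indexing stays consistent.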
        // Determine input dimension from first task
        let first_task = self.tasks.values().next().unwrap();
        let n_features = first_task.0.ncols();

        for (task_name, (X, y)) in &self.tasks {
            if X.ncols() != n_features {
                return Err(SklearsError::InvalidInput(
                    "All tasks must have the same number of features".to_string(),
                ));
            }

            let n_samples = X.nrows();
            task_indices.insert(task_name.clone(), (current_idx, current_idx + n_samples));

            // Append X data
            for i in 0..n_samples {
                let mut row = Vec::new();
                for j in 0..n_features {
                    row.push(X[[i, j]]);
                }
                all_X_vec.push(row);
            }

            // Append y data
            for i in 0..n_samples {
                all_y_vec.push(y[i]);
            }

            current_idx += n_samples;
        }

        let n_total = all_X_vec.len();
        let mut all_X = Array2::<f64>::zeros((n_total, n_features));
        let mut all_y = Array1::<f64>::zeros(n_total);

        for (i, row) in all_X_vec.iter().enumerate() {
            for (j, &val) in row.iter().enumerate() {
                all_X[[i, j]] = val;
            }
        }

        for (i, &val) in all_y_vec.iter().enumerate() {
            all_y[i] = val;
        }

        Ok((all_X, all_y, task_indices))
    }

    /// Compute the multi-task covariance matrix
    #[allow(non_snake_case)]
    fn compute_multitask_covariance(
        &self,
        X: &Array2<f64>,
        task_indices: &HashMap<String, (usize, usize)>,
        shared_kernel: &dyn Kernel,
        task_kernel: &dyn Kernel,
    ) -> SklResult<Array2<f64>> {
        let n = X.nrows();
        let mut K = Array2::<f64>::zeros((n, n));

        // Compute shared kernel matrix (applies to all pairs)
        let K_shared = shared_kernel.compute_kernel_matrix(X, None)?;

        // Add shared component
        for i in 0..n {
            for j in 0..n {
                K[[i, j]] += self.shared_weight * self.shared_weight * K_shared[[i, j]];
            }
        }

        // Add task-specific components (only within same task)
        let K_task = task_kernel.compute_kernel_matrix(X, None)?;

        for (start_i, end_i) in task_indices.values() {
            for i in *start_i..*end_i {
                for j in *start_i..*end_i {
                    K[[i, j]] += self.task_weight * self.task_weight * K_task[[i, j]];
                }
            }
        }

        Ok(K)
    }
}

impl Default for MultiTaskGaussianProcessRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTaskGaussianProcessRegressor<Untrained> {
    type Config = MtgpConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        static DEFAULT_CONFIG: MtgpConfig = MtgpConfig {
            shared_kernel_name: String::new(),
            task_kernel_name: String::new(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            random_state: None,
        };
        &DEFAULT_CONFIG
    }
}

impl Estimator for MultiTaskGaussianProcessRegressor<MtgpTrained> {
    type Config = MtgpConfig;
    type Error = SklearsError;
    type Float = f64;

    fn config(&self) -> &Self::Config {
        static DEFAULT_CONFIG: MtgpConfig = MtgpConfig {
            shared_kernel_name: String::new(),
            task_kernel_name: String::new(),
            alpha: 1e-10,
            shared_weight: 1.0,
            task_weight: 1.0,
            random_state: None,
        };
        &DEFAULT_CONFIG
    }
}
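
// Note: `config()` on both estimator states returns a static placeholder with
// empty kernel names rather than the instance's actual settings; read the
// kernels and weights from the builder or the trained state when the real
// values are needed.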

// Implement fit without requiring X and y parameters since tasks are added separately
impl MultiTaskGaussianProcessRegressor<Untrained> {
    /// Fit the multi-task Gaussian process
    #[allow(non_snake_case)]
    pub fn fit(self) -> SklResult<MultiTaskGaussianProcessRegressor<MtgpTrained>> {
        let shared_kernel = self.shared_kernel.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Shared kernel must be specified".to_string())
        })?;

        let task_kernel = self.task_kernel.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput("Task kernel must be specified".to_string())
        })?;

        if self.tasks.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one task must be added".to_string(),
            ));
        }

        // Combine all task data
        let (all_X, all_y, task_indices) = self.combine_task_data()?;

        // Compute multi-task covariance matrix
        let K =
            self.compute_multitask_covariance(&all_X, &task_indices, shared_kernel, task_kernel)?;

        // Add regularization
        let mut K_reg = K.clone();
        for i in 0..K_reg.nrows() {
            K_reg[[i, i]] += self.alpha;
        }

        // Solve the linear system
        let chol_decomp = utils::robust_cholesky(&K_reg)?;
        let alpha_vector = utils::triangular_solve(&chol_decomp, &all_y)?;
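        // Assumption: `utils::triangular_solve` performs the complete solve of
        // K_reg * alpha_vector = all_y given the Cholesky factor (forward and
        // back substitution, cho_solve-style). If it runs only a single
        // triangular solve, a second solve against the transposed factor would
        // still be required here.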

        // Compute per-task log marginal likelihood
        let mut log_marginal_likelihood_values = HashMap::new();
        for (task_name, (start_idx, end_idx)) in &task_indices {
            let task_size = end_idx - start_idx;
            let task_y = all_y.slice(scirs2_core::ndarray::s![*start_idx..*end_idx]);
            let task_alpha = alpha_vector.slice(scirs2_core::ndarray::s![*start_idx..*end_idx]);

            // Simplified per-task log marginal likelihood: the data-fit term
            // plus the normalization constant; the log-determinant term of the
            // covariance block is omitted.
            let data_fit = task_y.dot(&task_alpha);
            let log_ml = -0.5 * (data_fit + task_size as f64 * (2.0 * PI).ln());
            log_marginal_likelihood_values.insert(task_name.clone(), log_ml);
        }

        Ok(MultiTaskGaussianProcessRegressor {
            shared_kernel: None,
            task_kernel: None,
            tasks: self.tasks.clone(),
            alpha: self.alpha,
            shared_weight: self.shared_weight,
            task_weight: self.task_weight,
            _state: MtgpTrained {
                tasks: self.tasks,
                shared_kernel: shared_kernel.clone(),
                task_kernel: task_kernel.clone(),
                alpha: self.alpha,
                shared_weight: self.shared_weight,
                task_weight: self.task_weight,
                alpha_vector,
                log_marginal_likelihood_values,
                task_indices,
                all_X,
                all_y,
            },
        })
    }
}

impl MultiTaskGaussianProcessRegressor<MtgpTrained> {
    /// Access the trained state
    pub fn trained_state(&self) -> &MtgpTrained {
        &self._state
    }

    /// Get the log marginal likelihood for a specific task
    pub fn log_marginal_likelihood_task(&self, task_name: &str) -> Option<f64> {
        self._state
            .log_marginal_likelihood_values
            .get(task_name)
            .copied()
    }

    /// Get all log marginal likelihoods
    pub fn log_marginal_likelihoods(&self) -> &HashMap<String, f64> {
        &self._state.log_marginal_likelihood_values
    }

    /// Get the list of available tasks
    pub fn task_names(&self) -> Vec<String> {
        self._state.tasks.keys().cloned().collect()
    }

    /// Predict for a specific task
    #[allow(non_snake_case)]
    pub fn predict_task(&self, task_name: &str, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
        let task_data =
            self._state.tasks.get(task_name).ok_or_else(|| {
                SklearsError::InvalidInput(format!("Task '{}' not found", task_name))
            })?;

        let task_X_train = &task_data.0;
        let n_test = X.nrows();

        // Compute cross-covariance between test points and all training points
        let K_shared_star = self
            ._state
            .shared_kernel
            .compute_kernel_matrix(&self._state.all_X, Some(&X.to_owned()))?;
        let K_task_star = self
            ._state
            .task_kernel
            .compute_kernel_matrix(task_X_train, Some(&X.to_owned()))?;

        // For each test point, compute prediction
        let mut predictions = Array1::<f64>::zeros(n_test);

        for i in 0..n_test {
            let mut pred = 0.0;

            // Add shared component contribution from all tasks
            for j in 0..self._state.all_X.nrows() {
                pred += self.shared_weight
                    * self.shared_weight
                    * K_shared_star[[j, i]]
                    * self._state.alpha_vector[j];
            }

            // Add task-specific contribution only from the same task
            if let Some((start_idx, _end_idx)) = self._state.task_indices.get(task_name) {
                for j in 0..task_X_train.nrows() {
                    let global_j = start_idx + j;
                    pred += self.task_weight
                        * self.task_weight
                        * K_task_star[[j, i]]
                        * self._state.alpha_vector[global_j];
                }
            }

            predictions[i] = pred;
        }

        Ok(predictions)
    }

    /// Get shared and task-specific contributions separately
    ///
    /// Both components use the same squared weights as `predict_task`, so they
    /// sum to the full prediction for the task.
    #[allow(non_snake_case)]
    pub fn predict_task_components(
        &self,
        task_name: &str,
        X: &ArrayView2<f64>,
    ) -> SklResult<(Array1<f64>, Array1<f64>)> {
        let task_data =
            self._state.tasks.get(task_name).ok_or_else(|| {
                SklearsError::InvalidInput(format!("Task '{}' not found", task_name))
            })?;

        let task_X_train = &task_data.0;
        let n_test = X.nrows();

        // Compute cross-covariances
        let K_shared_star = self
            ._state
            .shared_kernel
            .compute_kernel_matrix(&self._state.all_X, Some(&X.to_owned()))?;
        let K_task_star = self
            ._state
            .task_kernel
            .compute_kernel_matrix(task_X_train, Some(&X.to_owned()))?;

        let mut shared_predictions = Array1::<f64>::zeros(n_test);
        let mut task_predictions = Array1::<f64>::zeros(n_test);

        for i in 0..n_test {
            // Shared component (weights squared, matching the covariance model)
            for j in 0..self._state.all_X.nrows() {
                shared_predictions[i] += self.shared_weight
                    * self.shared_weight
                    * K_shared_star[[j, i]]
                    * self._state.alpha_vector[j];
            }

            // Task-specific component (weights squared, matching `predict_task`)
            if let Some((start_idx, _)) = self._state.task_indices.get(task_name) {
                for j in 0..task_X_train.nrows() {
                    let global_j = start_idx + j;
                    task_predictions[i] += self.task_weight
                        * self.task_weight
                        * K_task_star[[j, i]]
                        * self._state.alpha_vector[global_j];
                }
            }
        }

        Ok((shared_predictions, task_predictions))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernels::RBF;

    // SciRS2 Policy - Use scirs2-autograd for array! macro and types
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mtgp_creation() {
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .alpha(1e-6);

        assert_eq!(mtgp.alpha, 1e-6);
        assert_eq!(mtgp.shared_weight, 1.0);
        assert_eq!(mtgp.task_weight, 1.0);
        assert_eq!(mtgp.tasks.len(), 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_add_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        assert_eq!(mtgp.tasks.len(), 1);
        assert!(mtgp.tasks.contains_key("task1"));
        let task_names = mtgp.task_names();
        assert!(task_names.contains(&"task1".to_string()));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_remove_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap()
            .remove_task("task1");

        assert_eq!(mtgp.tasks.len(), 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_fit_single_task() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert_eq!(fitted.task_names().len(), 1);
        assert!(fitted.log_marginal_likelihood_task("task1").is_some());
    }

    #[test]
    fn test_mtgp_fit_multiple_tasks() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert_eq!(fitted.task_names().len(), 2);
        assert!(fitted.log_marginal_likelihood_task("task1").is_some());
        assert!(fitted.log_marginal_likelihood_task("task2").is_some());
    }

    #[test]
    fn test_mtgp_predict_task() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();

        let predictions = fitted.predict_task("task1", &X1.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let predictions2 = fitted.predict_task("task2", &X2.view()).unwrap();
        assert_eq!(predictions2.len(), 4);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_predict_components() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        let (shared_pred, task_pred) = fitted.predict_task_components("task1", &X.view()).unwrap();

        assert_eq!(shared_pred.len(), 4);
        assert_eq!(task_pred.len(), 4);
    }
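
    // Sketch of an additional invariant check: with squared weights in both
    // paths, the shared and task components returned by
    // `predict_task_components` should sum to the output of `predict_task`,
    // up to floating-point accumulation order.
    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_components_sum_to_prediction() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(RBF::new(1.0)))
            .task_kernel(Box::new(RBF::new(0.5)))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        let full = fitted.predict_task("task1", &X.view()).unwrap();
        let (shared_pred, task_pred) = fitted.predict_task_components("task1", &X.view()).unwrap();

        for i in 0..full.len() {
            assert!((full[i] - (shared_pred[i] + task_pred[i])).abs() < 1e-9);
        }
    }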

    #[test]
    fn test_mtgp_log_marginal_likelihoods() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y1 = array![1.0, 4.0, 9.0, 16.0];
        let X2 = array![[1.5], [2.5], [3.5], [4.5]];
        let y2 = array![2.0, 6.0, 12.0, 20.0];

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        let all_lml = fitted.log_marginal_likelihoods();

        assert_eq!(all_lml.len(), 2);
        assert!(all_lml.contains_key("task1"));
        assert!(all_lml.contains_key("task2"));
        assert!(all_lml.get("task1").unwrap().is_finite());
        assert!(all_lml.get("task2").unwrap().is_finite());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mtgp_errors() {
        let X = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        // Test with no shared kernel
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();
        assert!(mtgp.fit().is_err());

        // Test with no task kernel
        let shared_kernel = RBF::new(1.0);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();
        assert!(mtgp.fit().is_err());

        // Test with no tasks
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel));
        assert!(mtgp.fit().is_err());

        // Test prediction on non-existent task
        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel))
            .add_task("task1", &X.view(), &y.view())
            .unwrap();

        let fitted = mtgp.fit().unwrap();
        assert!(fitted.predict_task("nonexistent", &X.view()).is_err());
    }

    #[test]
    fn test_mtgp_mismatched_dimensions() {
        let X1 = array![[1.0], [2.0], [3.0], [4.0]];
        let y_wrong = array![1.0, 4.0, 9.0]; // Wrong size

        let shared_kernel = RBF::new(1.0);
        let task_kernel = RBF::new(0.5);
        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(shared_kernel))
            .task_kernel(Box::new(task_kernel));

        assert!(mtgp.add_task("task1", &X1.view(), &y_wrong.view()).is_err());
    }
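
    // Sketch of a complementary check: `combine_task_data` (called from `fit`)
    // rejects tasks whose feature counts disagree, so fitting two tasks with
    // different input dimensions should fail even though each task is
    // internally consistent.
    #[test]
    fn test_mtgp_mismatched_feature_counts() {
        let X1 = array![[1.0], [2.0], [3.0]];
        let y1 = array![1.0, 4.0, 9.0];
        let X2 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y2 = array![2.0, 6.0, 12.0];

        let mtgp = MultiTaskGaussianProcessRegressor::new()
            .shared_kernel(Box::new(RBF::new(1.0)))
            .task_kernel(Box::new(RBF::new(0.5)))
            .add_task("task1", &X1.view(), &y1.view())
            .unwrap()
            .add_task("task2", &X2.view(), &y2.view())
            .unwrap();

        assert!(mtgp.fit().is_err());
    }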
}