Skip to main content

tensorlogic_sklears_kernels/
multitask.rs

1//! Multi-task kernel learning for related tasks with shared representations.
2//!
3//! This module provides kernels that can learn multiple related tasks simultaneously,
4//! sharing information between tasks through structured covariance matrices.
5//!
6//! ## Features
7//!
//! - **TaskInput** - Feature vector paired with a task index
9//! - **ICMKernel** - Intrinsic Coregionalization Model (B ⊗ K)
10//! - **LMCKernel** - Linear Model of Coregionalization (Σ B_q ⊗ K_q)
11//! - **IndexKernel** - Purely task-based similarity
12//!
13//! ## Use Cases
14//!
15//! - Multi-output regression
16//! - Transfer learning between related domains
17//! - Hierarchical task structures
18//! - Heterogeneous data fusion
19
20use crate::error::{KernelError, Result};
21use crate::types::Kernel;
22
/// A single observation for a multi-task kernel: a feature vector tagged
/// with the task it belongs to.
#[derive(Debug, Clone)]
pub struct TaskInput {
    /// Feature values of the observation.
    pub features: Vec<f64>,
    /// Zero-based index of the task this observation belongs to.
    pub task: usize,
}

impl TaskInput {
    /// Build a task input that takes ownership of `features`.
    pub fn new(features: Vec<f64>, task: usize) -> Self {
        TaskInput { features, task }
    }

    /// Build a task input by copying the borrowed `features` slice.
    pub fn from_slice(features: &[f64], task: usize) -> Self {
        Self::new(features.to_owned(), task)
    }
}
46
/// Settings shared by multi-task kernels.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// How many tasks the kernel models.
    pub num_tasks: usize,
    /// Flag requesting normalization of the task covariance matrix.
    pub normalize: bool,
}

impl MultiTaskConfig {
    /// Start a configuration for `num_tasks` tasks with normalization off.
    pub fn new(num_tasks: usize) -> Self {
        let normalize = false;
        MultiTaskConfig {
            num_tasks,
            normalize,
        }
    }

    /// Return the same configuration with the `normalize` flag switched on.
    pub fn with_normalization(self) -> Self {
        MultiTaskConfig {
            normalize: true,
            ..self
        }
    }
}
71
72/// Index kernel: K(i, j) = B[i, j] where B is task covariance matrix.
73///
74/// Pure task-based similarity without feature component.
75/// Useful as a building block for more complex multi-task kernels.
76#[derive(Debug, Clone)]
77pub struct IndexKernel {
78    /// Task covariance matrix (num_tasks x num_tasks)
79    task_covariance: Vec<Vec<f64>>,
80    /// Number of tasks
81    num_tasks: usize,
82}
83
84impl IndexKernel {
85    /// Create an index kernel from task covariance matrix.
86    ///
87    /// The covariance matrix should be symmetric positive semi-definite.
88    pub fn new(task_covariance: Vec<Vec<f64>>) -> Result<Self> {
89        let num_tasks = task_covariance.len();
90        if num_tasks == 0 {
91            return Err(KernelError::InvalidParameter {
92                parameter: "task_covariance".to_string(),
93                value: "empty".to_string(),
94                reason: "must have at least one task".to_string(),
95            });
96        }
97
98        // Validate square matrix
99        for (i, row) in task_covariance.iter().enumerate() {
100            if row.len() != num_tasks {
101                return Err(KernelError::InvalidParameter {
102                    parameter: "task_covariance".to_string(),
103                    value: format!("row {} has {} elements", i, row.len()),
104                    reason: format!("expected {} elements (square matrix)", num_tasks),
105                });
106            }
107        }
108
109        Ok(Self {
110            task_covariance,
111            num_tasks,
112        })
113    }
114
115    /// Create with identity covariance (independent tasks).
116    pub fn identity(num_tasks: usize) -> Result<Self> {
117        let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
118        for (i, row) in cov.iter_mut().enumerate() {
119            row[i] = 1.0;
120        }
121        Self::new(cov)
122    }
123
124    /// Create with uniform covariance (all tasks equally similar).
125    pub fn uniform(num_tasks: usize, correlation: f64) -> Result<Self> {
126        if !(0.0..=1.0).contains(&correlation) {
127            return Err(KernelError::InvalidParameter {
128                parameter: "correlation".to_string(),
129                value: correlation.to_string(),
130                reason: "must be in [0, 1]".to_string(),
131            });
132        }
133
134        let mut cov = vec![vec![correlation; num_tasks]; num_tasks];
135        for (i, row) in cov.iter_mut().enumerate() {
136            row[i] = 1.0;
137        }
138        Self::new(cov)
139    }
140
141    /// Get task covariance value.
142    pub fn get_task_covariance(&self, task_i: usize, task_j: usize) -> Result<f64> {
143        if task_i >= self.num_tasks || task_j >= self.num_tasks {
144            return Err(KernelError::ComputationError(format!(
145                "Task index out of bounds: ({}, {}) for {} tasks",
146                task_i, task_j, self.num_tasks
147            )));
148        }
149        Ok(self.task_covariance[task_i][task_j])
150    }
151
152    /// Get number of tasks.
153    pub fn num_tasks(&self) -> usize {
154        self.num_tasks
155    }
156
157    /// Get the full covariance matrix.
158    pub fn covariance_matrix(&self) -> &Vec<Vec<f64>> {
159        &self.task_covariance
160    }
161}
162
163/// Intrinsic Coregionalization Model (ICM) kernel.
164///
165/// K((x, i), (y, j)) = B[i, j] * k(x, y)
166///
167/// where:
168/// - B is the task covariance matrix (num_tasks x num_tasks)
169/// - k is the base kernel on features
170///
171/// This model assumes all tasks share the same underlying kernel
172/// but with different task-specific scales captured in B.
173pub struct ICMKernel {
174    /// Base kernel for features
175    base_kernel: Box<dyn Kernel>,
176    /// Task covariance/similarity matrix
177    task_covariance: Vec<Vec<f64>>,
178    /// Number of tasks
179    num_tasks: usize,
180}
181
182impl ICMKernel {
183    /// Create a new ICM kernel.
184    ///
185    /// # Arguments
186    /// * `base_kernel` - Kernel for feature similarity
187    /// * `task_covariance` - Positive semi-definite task covariance matrix
188    pub fn new(base_kernel: Box<dyn Kernel>, task_covariance: Vec<Vec<f64>>) -> Result<Self> {
189        let num_tasks = task_covariance.len();
190        if num_tasks == 0 {
191            return Err(KernelError::InvalidParameter {
192                parameter: "task_covariance".to_string(),
193                value: "empty".to_string(),
194                reason: "must have at least one task".to_string(),
195            });
196        }
197
198        // Validate square matrix
199        for (i, row) in task_covariance.iter().enumerate() {
200            if row.len() != num_tasks {
201                return Err(KernelError::InvalidParameter {
202                    parameter: "task_covariance".to_string(),
203                    value: format!("row {} has {} elements", i, row.len()),
204                    reason: format!("expected {} elements", num_tasks),
205                });
206            }
207        }
208
209        Ok(Self {
210            base_kernel,
211            task_covariance,
212            num_tasks,
213        })
214    }
215
216    /// Create ICM with identity task covariance (independent tasks).
217    pub fn independent(base_kernel: Box<dyn Kernel>, num_tasks: usize) -> Result<Self> {
218        let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
219        for (i, row) in cov.iter_mut().enumerate() {
220            row[i] = 1.0;
221        }
222        Self::new(base_kernel, cov)
223    }
224
225    /// Create ICM with uniform task correlation.
226    pub fn uniform(
227        base_kernel: Box<dyn Kernel>,
228        num_tasks: usize,
229        correlation: f64,
230    ) -> Result<Self> {
231        if !(0.0..=1.0).contains(&correlation) {
232            return Err(KernelError::InvalidParameter {
233                parameter: "correlation".to_string(),
234                value: correlation.to_string(),
235                reason: "must be in [0, 1]".to_string(),
236            });
237        }
238
239        let mut cov = vec![vec![correlation; num_tasks]; num_tasks];
240        for (i, row) in cov.iter_mut().enumerate() {
241            row[i] = 1.0;
242        }
243        Self::new(base_kernel, cov)
244    }
245
246    /// Create ICM from rank-1 decomposition B = v * v^T.
247    ///
248    /// This is useful when you have task-specific variances.
249    pub fn from_rank1(base_kernel: Box<dyn Kernel>, task_variances: Vec<f64>) -> Result<Self> {
250        let num_tasks = task_variances.len();
251        let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
252        for i in 0..num_tasks {
253            for j in 0..num_tasks {
254                cov[i][j] = task_variances[i].sqrt() * task_variances[j].sqrt();
255            }
256        }
257        Self::new(base_kernel, cov)
258    }
259
260    /// Compute ICM kernel value for task inputs.
261    pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
262        if x.task >= self.num_tasks || y.task >= self.num_tasks {
263            return Err(KernelError::ComputationError(format!(
264                "Task index out of bounds: ({}, {}) for {} tasks",
265                x.task, y.task, self.num_tasks
266            )));
267        }
268
269        let k_features = self.base_kernel.compute(&x.features, &y.features)?;
270        let b_tasks = self.task_covariance[x.task][y.task];
271
272        Ok(b_tasks * k_features)
273    }
274
275    /// Get number of tasks.
276    pub fn num_tasks(&self) -> usize {
277        self.num_tasks
278    }
279
280    /// Get task covariance matrix.
281    pub fn task_covariance(&self) -> &Vec<Vec<f64>> {
282        &self.task_covariance
283    }
284
285    /// Compute full kernel matrix for multiple task inputs.
286    pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
287        let n = inputs.len();
288        let mut matrix = vec![vec![0.0; n]; n];
289
290        for i in 0..n {
291            for j in i..n {
292                let k = self.compute_tasks(&inputs[i], &inputs[j])?;
293                matrix[i][j] = k;
294                matrix[j][i] = k;
295            }
296        }
297
298        Ok(matrix)
299    }
300}
301
302/// A single latent process component for LMC.
303struct LMCComponent {
304    /// Base kernel for this component
305    kernel: Box<dyn Kernel>,
306    /// Task covariance matrix for this component
307    task_covariance: Vec<Vec<f64>>,
308}
309
310/// Linear Model of Coregionalization (LMC) kernel.
311///
312/// K((x, i), (y, j)) = Σ_q B_q[i, j] * k_q(x, y)
313///
314/// where:
315/// - Each (B_q, k_q) pair represents a latent process
316/// - B_q is a task covariance matrix
317/// - k_q is a kernel function
318///
319/// LMC is more expressive than ICM as it allows different
320/// kernels to capture different aspects of task relationships.
321pub struct LMCKernel {
322    /// Latent process components
323    components: Vec<LMCComponent>,
324    /// Number of tasks
325    num_tasks: usize,
326}
327
328impl LMCKernel {
329    /// Create a new LMC kernel.
330    pub fn new(num_tasks: usize) -> Self {
331        Self {
332            components: Vec::new(),
333            num_tasks,
334        }
335    }
336
337    /// Add a latent process component.
338    pub fn add_component(
339        &mut self,
340        kernel: Box<dyn Kernel>,
341        task_covariance: Vec<Vec<f64>>,
342    ) -> Result<()> {
343        // Validate task covariance dimensions
344        if task_covariance.len() != self.num_tasks {
345            return Err(KernelError::InvalidParameter {
346                parameter: "task_covariance".to_string(),
347                value: format!("{} rows", task_covariance.len()),
348                reason: format!("expected {} rows", self.num_tasks),
349            });
350        }
351
352        for (i, row) in task_covariance.iter().enumerate() {
353            if row.len() != self.num_tasks {
354                return Err(KernelError::InvalidParameter {
355                    parameter: "task_covariance".to_string(),
356                    value: format!("row {} has {} elements", i, row.len()),
357                    reason: format!("expected {} elements", self.num_tasks),
358                });
359            }
360        }
361
362        self.components.push(LMCComponent {
363            kernel,
364            task_covariance,
365        });
366
367        Ok(())
368    }
369
370    /// Compute LMC kernel value for task inputs.
371    pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
372        if x.task >= self.num_tasks || y.task >= self.num_tasks {
373            return Err(KernelError::ComputationError(format!(
374                "Task index out of bounds: ({}, {}) for {} tasks",
375                x.task, y.task, self.num_tasks
376            )));
377        }
378
379        let mut result = 0.0;
380        for component in &self.components {
381            let k_features = component.kernel.compute(&x.features, &y.features)?;
382            let b_tasks = component.task_covariance[x.task][y.task];
383            result += b_tasks * k_features;
384        }
385
386        Ok(result)
387    }
388
389    /// Get number of components.
390    pub fn num_components(&self) -> usize {
391        self.components.len()
392    }
393
394    /// Get number of tasks.
395    pub fn num_tasks(&self) -> usize {
396        self.num_tasks
397    }
398
399    /// Compute full kernel matrix for multiple task inputs.
400    pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
401        let n = inputs.len();
402        let mut matrix = vec![vec![0.0; n]; n];
403
404        for i in 0..n {
405            for j in i..n {
406                let k = self.compute_tasks(&inputs[i], &inputs[j])?;
407                matrix[i][j] = k;
408                matrix[j][i] = k;
409            }
410        }
411
412        Ok(matrix)
413    }
414}
415
416/// Wrapper to use ICM kernel with standard Kernel trait.
417///
418/// Encodes task index in the first element of the input vector.
419pub struct ICMKernelWrapper {
420    inner: ICMKernel,
421}
422
423impl ICMKernelWrapper {
424    /// Create wrapper from ICM kernel.
425    pub fn new(inner: ICMKernel) -> Self {
426        Self { inner }
427    }
428
429    /// Get the inner ICM kernel.
430    pub fn inner(&self) -> &ICMKernel {
431        &self.inner
432    }
433}
434
435impl Kernel for ICMKernelWrapper {
436    /// Compute kernel where first element is task index.
437    ///
438    /// Input format: [task_index, feature_1, feature_2, ...]
439    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
440        if x.is_empty() || y.is_empty() {
441            return Err(KernelError::ComputationError(
442                "Input must have at least task index".to_string(),
443            ));
444        }
445
446        let task_x = x[0] as usize;
447        let task_y = y[0] as usize;
448        let features_x = &x[1..];
449        let features_y = &y[1..];
450
451        let input_x = TaskInput::from_slice(features_x, task_x);
452        let input_y = TaskInput::from_slice(features_y, task_y);
453
454        self.inner.compute_tasks(&input_x, &input_y)
455    }
456
457    fn name(&self) -> &str {
458        "ICM"
459    }
460}
461
462/// Wrapper to use LMC kernel with standard Kernel trait.
463pub struct LMCKernelWrapper {
464    inner: LMCKernel,
465}
466
467impl LMCKernelWrapper {
468    /// Create wrapper from LMC kernel.
469    pub fn new(inner: LMCKernel) -> Self {
470        Self { inner }
471    }
472
473    /// Get the inner LMC kernel.
474    pub fn inner(&self) -> &LMCKernel {
475        &self.inner
476    }
477}
478
479impl Kernel for LMCKernelWrapper {
480    /// Compute kernel where first element is task index.
481    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
482        if x.is_empty() || y.is_empty() {
483            return Err(KernelError::ComputationError(
484                "Input must have at least task index".to_string(),
485            ));
486        }
487
488        let task_x = x[0] as usize;
489        let task_y = y[0] as usize;
490        let features_x = &x[1..];
491        let features_y = &y[1..];
492
493        let input_x = TaskInput::from_slice(features_x, task_x);
494        let input_y = TaskInput::from_slice(features_y, task_y);
495
496        self.inner.compute_tasks(&input_x, &input_y)
497    }
498
499    fn name(&self) -> &str {
500        "LMC"
501    }
502}
503
504/// Hadamard (element-wise) product of multiple task kernels.
505///
506/// K((x, i), (y, j)) = Π_q K_q((x, i), (y, j))
507///
508/// Useful for combining different aspects of task similarity.
509pub struct HadamardTaskKernel {
510    /// Component kernels
511    kernels: Vec<ICMKernel>,
512}
513
514impl HadamardTaskKernel {
515    /// Create a new Hadamard task kernel.
516    pub fn new() -> Self {
517        Self {
518            kernels: Vec::new(),
519        }
520    }
521
522    /// Add a component kernel.
523    pub fn add_kernel(&mut self, kernel: ICMKernel) -> Result<()> {
524        if !self.kernels.is_empty() && kernel.num_tasks() != self.kernels[0].num_tasks() {
525            return Err(KernelError::InvalidParameter {
526                parameter: "num_tasks".to_string(),
527                value: kernel.num_tasks().to_string(),
528                reason: format!("expected {}", self.kernels[0].num_tasks()),
529            });
530        }
531        self.kernels.push(kernel);
532        Ok(())
533    }
534
535    /// Compute kernel value.
536    pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
537        if self.kernels.is_empty() {
538            return Err(KernelError::ComputationError(
539                "No component kernels added".to_string(),
540            ));
541        }
542
543        let mut result = 1.0;
544        for kernel in &self.kernels {
545            result *= kernel.compute_tasks(x, y)?;
546        }
547        Ok(result)
548    }
549
550    /// Get number of tasks.
551    pub fn num_tasks(&self) -> Option<usize> {
552        self.kernels.first().map(|k| k.num_tasks())
553    }
554}
555
556impl Default for HadamardTaskKernel {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
562/// Builder for creating multi-task kernels.
563pub struct MultiTaskKernelBuilder {
564    num_tasks: usize,
565    base_kernels: Vec<Box<dyn Kernel>>,
566    task_covariances: Vec<Vec<Vec<f64>>>,
567}
568
569impl MultiTaskKernelBuilder {
570    /// Create a new builder.
571    pub fn new(num_tasks: usize) -> Self {
572        Self {
573            num_tasks,
574            base_kernels: Vec::new(),
575            task_covariances: Vec::new(),
576        }
577    }
578
579    /// Add a component with its kernel and task covariance.
580    pub fn add_component(
581        mut self,
582        kernel: Box<dyn Kernel>,
583        task_covariance: Vec<Vec<f64>>,
584    ) -> Self {
585        self.base_kernels.push(kernel);
586        self.task_covariances.push(task_covariance);
587        self
588    }
589
590    /// Build an ICM kernel (single component).
591    pub fn build_icm(self) -> Result<ICMKernel> {
592        if self.base_kernels.len() != 1 {
593            return Err(KernelError::InvalidParameter {
594                parameter: "components".to_string(),
595                value: self.base_kernels.len().to_string(),
596                reason: "ICM requires exactly one component".to_string(),
597            });
598        }
599
600        let kernel = self.base_kernels.into_iter().next().unwrap();
601        let cov = self.task_covariances.into_iter().next().unwrap();
602        ICMKernel::new(kernel, cov)
603    }
604
605    /// Build an LMC kernel (multiple components).
606    pub fn build_lmc(self) -> Result<LMCKernel> {
607        let mut lmc = LMCKernel::new(self.num_tasks);
608
609        for (kernel, cov) in self.base_kernels.into_iter().zip(self.task_covariances) {
610            lmc.add_component(kernel, cov)?;
611        }
612
613        Ok(lmc)
614    }
615}
616
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    /// Absolute tolerance shared by the floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Assert that `actual` lies within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// 2x2 task covariance: unit diagonal, 0.5 off-diagonal.
    fn cov_2x2() -> Vec<Vec<f64>> {
        vec![vec![1.0, 0.5], vec![0.5, 1.0]]
    }

    /// 2x2 task covariance: diagonal 2.0, off-diagonal 1.0.
    fn cov_2x2_scaled() -> Vec<Vec<f64>> {
        vec![vec![2.0, 1.0], vec![1.0, 2.0]]
    }

    /// 3x3 task covariance used to provoke task-count mismatches.
    fn cov_3x3() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ]
    }

    /// ICM kernel over a linear base kernel with the `cov_2x2` covariance.
    fn linear_icm() -> ICMKernel {
        ICMKernel::new(Box::new(LinearKernel::new()), cov_2x2()).unwrap()
    }

    /// LMC kernel with a single linear component using `cov_2x2`.
    fn linear_lmc() -> LMCKernel {
        let mut lmc = LMCKernel::new(2);
        lmc.add_component(Box::new(LinearKernel::new()), cov_2x2())
            .unwrap();
        lmc
    }

    // ===== IndexKernel Tests =====

    #[test]
    fn test_index_kernel_basic() {
        let kernel = IndexKernel::new(cov_2x2()).unwrap();

        assert_eq!(kernel.num_tasks(), 2);
        assert_close(kernel.get_task_covariance(0, 1).unwrap(), 0.5);
        assert_close(kernel.get_task_covariance(1, 1).unwrap(), 1.0);
    }

    #[test]
    fn test_index_kernel_identity() {
        let kernel = IndexKernel::identity(3).unwrap();

        assert_close(kernel.get_task_covariance(0, 0).unwrap(), 1.0);
        assert_close(kernel.get_task_covariance(0, 1).unwrap(), 0.0);
        assert_close(kernel.get_task_covariance(1, 2).unwrap(), 0.0);
    }

    #[test]
    fn test_index_kernel_uniform() {
        let kernel = IndexKernel::uniform(3, 0.5).unwrap();

        assert_close(kernel.get_task_covariance(0, 0).unwrap(), 1.0);
        assert_close(kernel.get_task_covariance(0, 1).unwrap(), 0.5);
        assert_close(kernel.get_task_covariance(1, 2).unwrap(), 0.5);
    }

    #[test]
    fn test_index_kernel_invalid() {
        // Empty matrix
        assert!(IndexKernel::new(vec![]).is_err());

        // Non-square matrix
        assert!(IndexKernel::new(vec![vec![1.0, 0.5]]).is_err());

        // Correlation outside [0, 1]
        assert!(IndexKernel::uniform(3, 1.5).is_err());
    }

    #[test]
    fn test_index_kernel_out_of_bounds() {
        let kernel = IndexKernel::identity(2).unwrap();
        assert!(kernel.get_task_covariance(2, 0).is_err());
    }

    // ===== ICMKernel Tests =====

    #[test]
    fn test_icm_kernel_basic() {
        assert_eq!(linear_icm().num_tasks(), 2);
    }

    #[test]
    fn test_icm_kernel_compute() {
        let icm = linear_icm();
        let x = TaskInput::new(vec![1.0, 2.0], 0);
        let y = TaskInput::new(vec![3.0, 4.0], 1);

        // Linear: 1*3 + 2*4 = 11; task covariance: 0.5; result: 5.5
        assert_close(icm.compute_tasks(&x, &y).unwrap(), 5.5);
    }

    #[test]
    fn test_icm_kernel_same_task() {
        let icm = linear_icm();
        let x = TaskInput::new(vec![1.0, 2.0], 0);
        let y = TaskInput::new(vec![3.0, 4.0], 0);

        // Linear: 11; task covariance: 1.0; result: 11.0
        assert_close(icm.compute_tasks(&x, &y).unwrap(), 11.0);
    }

    #[test]
    fn test_icm_kernel_independent() {
        let icm = ICMKernel::independent(Box::new(LinearKernel::new()), 3).unwrap();
        let x = TaskInput::new(vec![1.0], 0);

        // Different tasks are uncorrelated => 0
        let other_task = TaskInput::new(vec![1.0], 1);
        assert_close(icm.compute_tasks(&x, &other_task).unwrap(), 0.0);

        // Same task => base kernel value
        let same_task = TaskInput::new(vec![1.0], 0);
        assert_close(icm.compute_tasks(&x, &same_task).unwrap(), 1.0);
    }

    #[test]
    fn test_icm_kernel_uniform() {
        let icm = ICMKernel::uniform(Box::new(LinearKernel::new()), 2, 0.8).unwrap();
        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        // Linear: 1.0; task covariance: 0.8; result: 0.8
        assert_close(icm.compute_tasks(&x, &y).unwrap(), 0.8);
    }

    #[test]
    fn test_icm_kernel_rank1() {
        // Square roots of the variances are [1.0, 2.0].
        let variances = vec![1.0, 4.0];
        let icm = ICMKernel::from_rank1(Box::new(LinearKernel::new()), variances).unwrap();

        // B[0,1] = sqrt(1) * sqrt(4) = 2.0
        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);
        assert_close(icm.compute_tasks(&x, &y).unwrap(), 2.0);
    }

    #[test]
    fn test_icm_kernel_matrix() {
        let icm = linear_icm();
        let inputs = vec![
            TaskInput::new(vec![1.0], 0),
            TaskInput::new(vec![1.0], 1),
            TaskInput::new(vec![2.0], 0),
        ];

        let matrix = icm.compute_task_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 3);

        // Every entry must equal its mirror image.
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-10,
                    "Matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_icm_kernel_invalid_task() {
        let icm = linear_icm();
        let x = TaskInput::new(vec![1.0], 0);
        let out_of_bounds = TaskInput::new(vec![1.0], 5);

        assert!(icm.compute_tasks(&x, &out_of_bounds).is_err());
    }

    // ===== LMCKernel Tests =====

    #[test]
    fn test_lmc_kernel_basic() {
        let lmc = linear_lmc();

        assert_eq!(lmc.num_tasks(), 2);
        assert_eq!(lmc.num_components(), 1);
    }

    #[test]
    fn test_lmc_kernel_compute() {
        // Component 1: linear kernel with `cov_2x2`.
        let mut lmc = linear_lmc();

        // Component 2: RBF kernel with a differently scaled covariance.
        let rbf = RbfKernel::new(RbfKernelConfig::new(1.0)).unwrap();
        lmc.add_component(Box::new(rbf), cov_2x2_scaled()).unwrap();

        let x = TaskInput::new(vec![1.0, 0.0], 0);
        let y = TaskInput::new(vec![1.0, 0.0], 1);

        // Linear: 1.0 * 0.5 = 0.5; RBF: 1.0 * 1.0 = 1.0; sum: 1.5
        assert_close(lmc.compute_tasks(&x, &y).unwrap(), 1.5);
    }

    #[test]
    fn test_lmc_kernel_matrix() {
        let lmc = linear_lmc();
        let inputs = vec![TaskInput::new(vec![1.0], 0), TaskInput::new(vec![1.0], 1)];

        let matrix = lmc.compute_task_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 2);

        // Symmetry of the single off-diagonal pair.
        assert_close(matrix[0][1], matrix[1][0]);
    }

    #[test]
    fn test_lmc_kernel_invalid_dimensions() {
        let mut lmc = LMCKernel::new(2);

        // A 3x3 covariance does not fit a 2-task kernel.
        let result = lmc.add_component(Box::new(LinearKernel::new()), cov_3x3());
        assert!(result.is_err());
    }

    // ===== Wrapper Tests =====

    #[test]
    fn test_icm_wrapper() {
        let wrapper = ICMKernelWrapper::new(linear_icm());

        // Packed as [task, features...]
        let x = vec![0.0, 1.0, 2.0]; // Task 0
        let y = vec![1.0, 3.0, 4.0]; // Task 1

        // Linear: 11; task covariance: 0.5; result: 5.5
        assert_close(wrapper.compute(&x, &y).unwrap(), 5.5);
        assert_eq!(wrapper.name(), "ICM");
    }

    #[test]
    fn test_lmc_wrapper() {
        let wrapper = LMCKernelWrapper::new(linear_lmc());

        let x = vec![0.0, 1.0]; // Task 0
        let y = vec![1.0, 1.0]; // Task 1

        assert_close(wrapper.compute(&x, &y).unwrap(), 0.5);
        assert_eq!(wrapper.name(), "LMC");
    }

    #[test]
    fn test_wrapper_empty_input() {
        let icm = ICMKernel::new(Box::new(LinearKernel::new()), vec![vec![1.0]]).unwrap();
        let wrapper = ICMKernelWrapper::new(icm);

        assert!(wrapper.compute(&[], &[0.0, 1.0]).is_err());
    }

    // ===== HadamardTaskKernel Tests =====

    #[test]
    fn test_hadamard_task_kernel() {
        let mut hadamard = HadamardTaskKernel::new();
        hadamard.add_kernel(linear_icm()).unwrap();

        let scaled =
            ICMKernel::new(Box::new(LinearKernel::new()), cov_2x2_scaled()).unwrap();
        hadamard.add_kernel(scaled).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        // ICM1: 0.5 * 1.0 = 0.5; ICM2: 1.0 * 1.0 = 1.0; product: 0.5
        assert_close(hadamard.compute_tasks(&x, &y).unwrap(), 0.5);
    }

    #[test]
    fn test_hadamard_task_kernel_empty() {
        let hadamard = HadamardTaskKernel::new();
        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 0);

        assert!(hadamard.compute_tasks(&x, &y).is_err());
    }

    #[test]
    fn test_hadamard_mismatched_tasks() {
        let mut hadamard = HadamardTaskKernel::new();
        hadamard.add_kernel(linear_icm()).unwrap();

        // A factor with a different number of tasks must be rejected.
        let three_tasks = ICMKernel::new(Box::new(LinearKernel::new()), cov_3x3()).unwrap();
        assert!(hadamard.add_kernel(three_tasks).is_err());
    }

    // ===== Builder Tests =====

    #[test]
    fn test_builder_icm() {
        let icm = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(LinearKernel::new()), cov_2x2())
            .build_icm()
            .unwrap();

        assert_eq!(icm.num_tasks(), 2);
    }

    #[test]
    fn test_builder_lmc() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(1.0)).unwrap();

        let lmc = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(LinearKernel::new()), cov_2x2())
            .add_component(Box::new(rbf), cov_2x2_scaled())
            .build_lmc()
            .unwrap();

        assert_eq!(lmc.num_tasks(), 2);
        assert_eq!(lmc.num_components(), 2);
    }

    #[test]
    fn test_builder_icm_wrong_components() {
        // ICM accepts exactly one component; two must fail.
        let result = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(LinearKernel::new()), cov_2x2())
            .add_component(Box::new(LinearKernel::new()), cov_2x2())
            .build_icm();

        assert!(result.is_err());
    }

    // ===== Integration Tests =====

    #[test]
    fn test_multitask_with_rbf() {
        let base = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let cov = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();
        let x = TaskInput::new(vec![1.0, 2.0], 0);

        // Same point, same task
        assert_close(icm.compute_tasks(&x, &x).unwrap(), 1.0);

        // Same point, different task
        let y = TaskInput::new(vec![1.0, 2.0], 1);
        assert_close(icm.compute_tasks(&x, &y).unwrap(), 0.8);

        // Different point, same task
        let z = TaskInput::new(vec![1.0, 3.0], 0);
        let k = icm.compute_tasks(&x, &z).unwrap();
        // RBF with distance 1.0: exp(-0.5 * 1) ≈ 0.6065
        assert!(k > 0.5 && k < 0.7);
    }

    #[test]
    fn test_task_input_creation() {
        let owned = TaskInput::new(vec![1.0, 2.0, 3.0], 0);
        assert_eq!(owned.features, vec![1.0, 2.0, 3.0]);
        assert_eq!(owned.task, 0);

        let copied = TaskInput::from_slice(&[4.0, 5.0], 2);
        assert_eq!(copied.features, vec![4.0, 5.0]);
        assert_eq!(copied.task, 2);
    }
}