1use crate::error::{KernelError, Result};
21use crate::types::Kernel;
22
/// A feature vector tagged with the index of the task it belongs to.
#[derive(Debug, Clone)]
pub struct TaskInput {
    /// Feature vector handed to the base kernel(s).
    pub features: Vec<f64>,
    /// Zero-based task index; multi-task kernels reject indices `>= num_tasks`.
    pub task: usize,
}

impl TaskInput {
    /// Builds an input that takes ownership of `features`.
    pub fn new(features: Vec<f64>, task: usize) -> Self {
        TaskInput { task, features }
    }

    /// Builds an input by copying `features` out of a borrowed slice.
    pub fn from_slice(features: &[f64], task: usize) -> Self {
        Self::new(features.to_vec(), task)
    }
}
46
/// Plain configuration record for multi-task kernels.
#[derive(Debug, Clone)]
pub struct MultiTaskConfig {
    /// Number of tasks.
    pub num_tasks: usize,
    /// Normalization flag; `false` unless `with_normalization` is called.
    /// NOTE(review): nothing in this file reads it — confirm where it is consumed.
    pub normalize: bool,
}

impl MultiTaskConfig {
    /// Creates a configuration for `num_tasks` tasks with normalization off.
    pub fn new(num_tasks: usize) -> Self {
        let normalize = false;
        MultiTaskConfig {
            num_tasks,
            normalize,
        }
    }

    /// Consumes the configuration and returns it with normalization enabled.
    pub fn with_normalization(self) -> Self {
        Self {
            normalize: true,
            ..self
        }
    }
}
71
72#[derive(Debug, Clone)]
77pub struct IndexKernel {
78 task_covariance: Vec<Vec<f64>>,
80 num_tasks: usize,
82}
83
84impl IndexKernel {
85 pub fn new(task_covariance: Vec<Vec<f64>>) -> Result<Self> {
89 let num_tasks = task_covariance.len();
90 if num_tasks == 0 {
91 return Err(KernelError::InvalidParameter {
92 parameter: "task_covariance".to_string(),
93 value: "empty".to_string(),
94 reason: "must have at least one task".to_string(),
95 });
96 }
97
98 for (i, row) in task_covariance.iter().enumerate() {
100 if row.len() != num_tasks {
101 return Err(KernelError::InvalidParameter {
102 parameter: "task_covariance".to_string(),
103 value: format!("row {} has {} elements", i, row.len()),
104 reason: format!("expected {} elements (square matrix)", num_tasks),
105 });
106 }
107 }
108
109 Ok(Self {
110 task_covariance,
111 num_tasks,
112 })
113 }
114
115 pub fn identity(num_tasks: usize) -> Result<Self> {
117 let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
118 for (i, row) in cov.iter_mut().enumerate() {
119 row[i] = 1.0;
120 }
121 Self::new(cov)
122 }
123
124 pub fn uniform(num_tasks: usize, correlation: f64) -> Result<Self> {
126 if !(0.0..=1.0).contains(&correlation) {
127 return Err(KernelError::InvalidParameter {
128 parameter: "correlation".to_string(),
129 value: correlation.to_string(),
130 reason: "must be in [0, 1]".to_string(),
131 });
132 }
133
134 let mut cov = vec![vec![correlation; num_tasks]; num_tasks];
135 for (i, row) in cov.iter_mut().enumerate() {
136 row[i] = 1.0;
137 }
138 Self::new(cov)
139 }
140
141 pub fn get_task_covariance(&self, task_i: usize, task_j: usize) -> Result<f64> {
143 if task_i >= self.num_tasks || task_j >= self.num_tasks {
144 return Err(KernelError::ComputationError(format!(
145 "Task index out of bounds: ({}, {}) for {} tasks",
146 task_i, task_j, self.num_tasks
147 )));
148 }
149 Ok(self.task_covariance[task_i][task_j])
150 }
151
152 pub fn num_tasks(&self) -> usize {
154 self.num_tasks
155 }
156
157 pub fn covariance_matrix(&self) -> &Vec<Vec<f64>> {
159 &self.task_covariance
160 }
161}
162
163pub struct ICMKernel {
174 base_kernel: Box<dyn Kernel>,
176 task_covariance: Vec<Vec<f64>>,
178 num_tasks: usize,
180}
181
182impl ICMKernel {
183 pub fn new(base_kernel: Box<dyn Kernel>, task_covariance: Vec<Vec<f64>>) -> Result<Self> {
189 let num_tasks = task_covariance.len();
190 if num_tasks == 0 {
191 return Err(KernelError::InvalidParameter {
192 parameter: "task_covariance".to_string(),
193 value: "empty".to_string(),
194 reason: "must have at least one task".to_string(),
195 });
196 }
197
198 for (i, row) in task_covariance.iter().enumerate() {
200 if row.len() != num_tasks {
201 return Err(KernelError::InvalidParameter {
202 parameter: "task_covariance".to_string(),
203 value: format!("row {} has {} elements", i, row.len()),
204 reason: format!("expected {} elements", num_tasks),
205 });
206 }
207 }
208
209 Ok(Self {
210 base_kernel,
211 task_covariance,
212 num_tasks,
213 })
214 }
215
216 pub fn independent(base_kernel: Box<dyn Kernel>, num_tasks: usize) -> Result<Self> {
218 let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
219 for (i, row) in cov.iter_mut().enumerate() {
220 row[i] = 1.0;
221 }
222 Self::new(base_kernel, cov)
223 }
224
225 pub fn uniform(
227 base_kernel: Box<dyn Kernel>,
228 num_tasks: usize,
229 correlation: f64,
230 ) -> Result<Self> {
231 if !(0.0..=1.0).contains(&correlation) {
232 return Err(KernelError::InvalidParameter {
233 parameter: "correlation".to_string(),
234 value: correlation.to_string(),
235 reason: "must be in [0, 1]".to_string(),
236 });
237 }
238
239 let mut cov = vec![vec![correlation; num_tasks]; num_tasks];
240 for (i, row) in cov.iter_mut().enumerate() {
241 row[i] = 1.0;
242 }
243 Self::new(base_kernel, cov)
244 }
245
246 pub fn from_rank1(base_kernel: Box<dyn Kernel>, task_variances: Vec<f64>) -> Result<Self> {
250 let num_tasks = task_variances.len();
251 let mut cov = vec![vec![0.0; num_tasks]; num_tasks];
252 for i in 0..num_tasks {
253 for j in 0..num_tasks {
254 cov[i][j] = task_variances[i].sqrt() * task_variances[j].sqrt();
255 }
256 }
257 Self::new(base_kernel, cov)
258 }
259
260 pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
262 if x.task >= self.num_tasks || y.task >= self.num_tasks {
263 return Err(KernelError::ComputationError(format!(
264 "Task index out of bounds: ({}, {}) for {} tasks",
265 x.task, y.task, self.num_tasks
266 )));
267 }
268
269 let k_features = self.base_kernel.compute(&x.features, &y.features)?;
270 let b_tasks = self.task_covariance[x.task][y.task];
271
272 Ok(b_tasks * k_features)
273 }
274
275 pub fn num_tasks(&self) -> usize {
277 self.num_tasks
278 }
279
280 pub fn task_covariance(&self) -> &Vec<Vec<f64>> {
282 &self.task_covariance
283 }
284
285 pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
287 let n = inputs.len();
288 let mut matrix = vec![vec![0.0; n]; n];
289
290 for i in 0..n {
291 for j in i..n {
292 let k = self.compute_tasks(&inputs[i], &inputs[j])?;
293 matrix[i][j] = k;
294 matrix[j][i] = k;
295 }
296 }
297
298 Ok(matrix)
299 }
300}
301
/// One term of an LMC sum: a kernel over features paired with its own task
/// covariance matrix.
struct LMCComponent {
    /// Kernel evaluated on the feature vectors for this term.
    kernel: Box<dyn Kernel>,
    /// Task covariance for this term; `LMCKernel::add_component` checks that it
    /// is `num_tasks x num_tasks` before storing it.
    task_covariance: Vec<Vec<f64>>,
}
309
310pub struct LMCKernel {
322 components: Vec<LMCComponent>,
324 num_tasks: usize,
326}
327
328impl LMCKernel {
329 pub fn new(num_tasks: usize) -> Self {
331 Self {
332 components: Vec::new(),
333 num_tasks,
334 }
335 }
336
337 pub fn add_component(
339 &mut self,
340 kernel: Box<dyn Kernel>,
341 task_covariance: Vec<Vec<f64>>,
342 ) -> Result<()> {
343 if task_covariance.len() != self.num_tasks {
345 return Err(KernelError::InvalidParameter {
346 parameter: "task_covariance".to_string(),
347 value: format!("{} rows", task_covariance.len()),
348 reason: format!("expected {} rows", self.num_tasks),
349 });
350 }
351
352 for (i, row) in task_covariance.iter().enumerate() {
353 if row.len() != self.num_tasks {
354 return Err(KernelError::InvalidParameter {
355 parameter: "task_covariance".to_string(),
356 value: format!("row {} has {} elements", i, row.len()),
357 reason: format!("expected {} elements", self.num_tasks),
358 });
359 }
360 }
361
362 self.components.push(LMCComponent {
363 kernel,
364 task_covariance,
365 });
366
367 Ok(())
368 }
369
370 pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
372 if x.task >= self.num_tasks || y.task >= self.num_tasks {
373 return Err(KernelError::ComputationError(format!(
374 "Task index out of bounds: ({}, {}) for {} tasks",
375 x.task, y.task, self.num_tasks
376 )));
377 }
378
379 let mut result = 0.0;
380 for component in &self.components {
381 let k_features = component.kernel.compute(&x.features, &y.features)?;
382 let b_tasks = component.task_covariance[x.task][y.task];
383 result += b_tasks * k_features;
384 }
385
386 Ok(result)
387 }
388
389 pub fn num_components(&self) -> usize {
391 self.components.len()
392 }
393
394 pub fn num_tasks(&self) -> usize {
396 self.num_tasks
397 }
398
399 pub fn compute_task_matrix(&self, inputs: &[TaskInput]) -> Result<Vec<Vec<f64>>> {
401 let n = inputs.len();
402 let mut matrix = vec![vec![0.0; n]; n];
403
404 for i in 0..n {
405 for j in i..n {
406 let k = self.compute_tasks(&inputs[i], &inputs[j])?;
407 matrix[i][j] = k;
408 matrix[j][i] = k;
409 }
410 }
411
412 Ok(matrix)
413 }
414}
415
416pub struct ICMKernelWrapper {
420 inner: ICMKernel,
421}
422
423impl ICMKernelWrapper {
424 pub fn new(inner: ICMKernel) -> Self {
426 Self { inner }
427 }
428
429 pub fn inner(&self) -> &ICMKernel {
431 &self.inner
432 }
433}
434
435impl Kernel for ICMKernelWrapper {
436 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
440 if x.is_empty() || y.is_empty() {
441 return Err(KernelError::ComputationError(
442 "Input must have at least task index".to_string(),
443 ));
444 }
445
446 let task_x = x[0] as usize;
447 let task_y = y[0] as usize;
448 let features_x = &x[1..];
449 let features_y = &y[1..];
450
451 let input_x = TaskInput::from_slice(features_x, task_x);
452 let input_y = TaskInput::from_slice(features_y, task_y);
453
454 self.inner.compute_tasks(&input_x, &input_y)
455 }
456
457 fn name(&self) -> &str {
458 "ICM"
459 }
460}
461
462pub struct LMCKernelWrapper {
464 inner: LMCKernel,
465}
466
467impl LMCKernelWrapper {
468 pub fn new(inner: LMCKernel) -> Self {
470 Self { inner }
471 }
472
473 pub fn inner(&self) -> &LMCKernel {
475 &self.inner
476 }
477}
478
479impl Kernel for LMCKernelWrapper {
480 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
482 if x.is_empty() || y.is_empty() {
483 return Err(KernelError::ComputationError(
484 "Input must have at least task index".to_string(),
485 ));
486 }
487
488 let task_x = x[0] as usize;
489 let task_y = y[0] as usize;
490 let features_x = &x[1..];
491 let features_y = &y[1..];
492
493 let input_x = TaskInput::from_slice(features_x, task_x);
494 let input_y = TaskInput::from_slice(features_y, task_y);
495
496 self.inner.compute_tasks(&input_x, &input_y)
497 }
498
499 fn name(&self) -> &str {
500 "LMC"
501 }
502}
503
504pub struct HadamardTaskKernel {
510 kernels: Vec<ICMKernel>,
512}
513
514impl HadamardTaskKernel {
515 pub fn new() -> Self {
517 Self {
518 kernels: Vec::new(),
519 }
520 }
521
522 pub fn add_kernel(&mut self, kernel: ICMKernel) -> Result<()> {
524 if !self.kernels.is_empty() && kernel.num_tasks() != self.kernels[0].num_tasks() {
525 return Err(KernelError::InvalidParameter {
526 parameter: "num_tasks".to_string(),
527 value: kernel.num_tasks().to_string(),
528 reason: format!("expected {}", self.kernels[0].num_tasks()),
529 });
530 }
531 self.kernels.push(kernel);
532 Ok(())
533 }
534
535 pub fn compute_tasks(&self, x: &TaskInput, y: &TaskInput) -> Result<f64> {
537 if self.kernels.is_empty() {
538 return Err(KernelError::ComputationError(
539 "No component kernels added".to_string(),
540 ));
541 }
542
543 let mut result = 1.0;
544 for kernel in &self.kernels {
545 result *= kernel.compute_tasks(x, y)?;
546 }
547 Ok(result)
548 }
549
550 pub fn num_tasks(&self) -> Option<usize> {
552 self.kernels.first().map(|k| k.num_tasks())
553 }
554}
555
556impl Default for HadamardTaskKernel {
557 fn default() -> Self {
558 Self::new()
559 }
560}
561
562pub struct MultiTaskKernelBuilder {
564 num_tasks: usize,
565 base_kernels: Vec<Box<dyn Kernel>>,
566 task_covariances: Vec<Vec<Vec<f64>>>,
567}
568
569impl MultiTaskKernelBuilder {
570 pub fn new(num_tasks: usize) -> Self {
572 Self {
573 num_tasks,
574 base_kernels: Vec::new(),
575 task_covariances: Vec::new(),
576 }
577 }
578
579 pub fn add_component(
581 mut self,
582 kernel: Box<dyn Kernel>,
583 task_covariance: Vec<Vec<f64>>,
584 ) -> Self {
585 self.base_kernels.push(kernel);
586 self.task_covariances.push(task_covariance);
587 self
588 }
589
590 pub fn build_icm(self) -> Result<ICMKernel> {
592 if self.base_kernels.len() != 1 {
593 return Err(KernelError::InvalidParameter {
594 parameter: "components".to_string(),
595 value: self.base_kernels.len().to_string(),
596 reason: "ICM requires exactly one component".to_string(),
597 });
598 }
599
600 let kernel = self.base_kernels.into_iter().next().unwrap();
601 let cov = self.task_covariances.into_iter().next().unwrap();
602 ICMKernel::new(kernel, cov)
603 }
604
605 pub fn build_lmc(self) -> Result<LMCKernel> {
607 let mut lmc = LMCKernel::new(self.num_tasks);
608
609 for (kernel, cov) in self.base_kernels.into_iter().zip(self.task_covariances) {
610 lmc.add_component(kernel, cov)?;
611 }
612
613 Ok(lmc)
614 }
615}
616
#[cfg(test)]
// The matrix tests compare `matrix[i][j]` with `matrix[j][i]`, which reads more
// clearly with index loops than with iterator chains.
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    // ---------------------------------------------------------------------
    // IndexKernel
    // ---------------------------------------------------------------------

    #[test]
    fn test_index_kernel_basic() {
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let kernel = IndexKernel::new(cov).unwrap();

        assert_eq!(kernel.num_tasks(), 2);
        assert!((kernel.get_task_covariance(0, 1).unwrap() - 0.5).abs() < 1e-10);
        assert!((kernel.get_task_covariance(1, 1).unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_index_kernel_identity() {
        let kernel = IndexKernel::identity(3).unwrap();

        // Unit diagonal, zero off-diagonal.
        assert!((kernel.get_task_covariance(0, 0).unwrap() - 1.0).abs() < 1e-10);
        assert!((kernel.get_task_covariance(0, 1).unwrap()).abs() < 1e-10);
        assert!((kernel.get_task_covariance(1, 2).unwrap()).abs() < 1e-10);
    }

    #[test]
    fn test_index_kernel_uniform() {
        let kernel = IndexKernel::uniform(3, 0.5).unwrap();

        // Unit diagonal, every off-diagonal entry equals the correlation.
        assert!((kernel.get_task_covariance(0, 0).unwrap() - 1.0).abs() < 1e-10);
        assert!((kernel.get_task_covariance(0, 1).unwrap() - 0.5).abs() < 1e-10);
        assert!((kernel.get_task_covariance(1, 2).unwrap() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_index_kernel_invalid() {
        // Empty matrix.
        let result = IndexKernel::new(vec![]);
        assert!(result.is_err());

        // Non-square: one row of two elements.
        let result = IndexKernel::new(vec![vec![1.0, 0.5]]);
        assert!(result.is_err());

        // Correlation outside [0, 1].
        let result = IndexKernel::uniform(3, 1.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_kernel_out_of_bounds() {
        let kernel = IndexKernel::identity(2).unwrap();
        assert!(kernel.get_task_covariance(2, 0).is_err());
    }

    // ---------------------------------------------------------------------
    // ICMKernel
    // ---------------------------------------------------------------------

    #[test]
    fn test_icm_kernel_basic() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        assert_eq!(icm.num_tasks(), 2);
    }

    #[test]
    fn test_icm_kernel_compute() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        let x = TaskInput::new(vec![1.0, 2.0], 0);
        let y = TaskInput::new(vec![3.0, 4.0], 1);

        // Expected value assumes LinearKernel is the plain dot product:
        // k(x, y) = 1*3 + 2*4 = 11, scaled by B[0][1] = 0.5 gives 5.5.
        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!((k - 5.5).abs() < 1e-10);
    }

    #[test]
    fn test_icm_kernel_same_task() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        let x = TaskInput::new(vec![1.0, 2.0], 0);
        let y = TaskInput::new(vec![3.0, 4.0], 0);

        // Same task: B[0][0] = 1.0, so the result is the base kernel value 11.
        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!((k - 11.0).abs() < 1e-10);
    }

    #[test]
    fn test_icm_kernel_independent() {
        let base = LinearKernel::new();
        let icm = ICMKernel::independent(Box::new(base), 3).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        // Identity task covariance: distinct tasks are uncorrelated.
        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!(k.abs() < 1e-10);

        // Same task 0: B[0][0] = 1 times a base value of 1.
        let z = TaskInput::new(vec![1.0], 0);
        let k = icm.compute_tasks(&x, &z).unwrap();
        assert!((k - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_icm_kernel_uniform() {
        let base = LinearKernel::new();
        let icm = ICMKernel::uniform(Box::new(base), 2, 0.8).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        // Base value 1 scaled by the uniform correlation 0.8.
        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!((k - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_icm_kernel_rank1() {
        let base = LinearKernel::new();
        // Variances 1 and 4 give B[0][1] = sqrt(1) * sqrt(4) = 2.
        let variances = vec![1.0, 4.0];
        let icm = ICMKernel::from_rank1(Box::new(base), variances).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!((k - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_icm_kernel_matrix() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        let inputs = vec![
            TaskInput::new(vec![1.0], 0),
            TaskInput::new(vec![1.0], 1),
            TaskInput::new(vec![2.0], 0),
        ];

        let matrix = icm.compute_task_matrix(&inputs).unwrap();

        assert_eq!(matrix.len(), 3);
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (matrix[i][j] - matrix[j][i]).abs() < 1e-10,
                    "Matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_icm_kernel_invalid_task() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        // Task 5 does not exist in a 2-task kernel.
        let y = TaskInput::new(vec![1.0], 5);
        assert!(icm.compute_tasks(&x, &y).is_err());
    }

    // ---------------------------------------------------------------------
    // LMCKernel
    // ---------------------------------------------------------------------

    #[test]
    fn test_lmc_kernel_basic() {
        let mut lmc = LMCKernel::new(2);

        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        lmc.add_component(Box::new(base1), cov1).unwrap();

        assert_eq!(lmc.num_tasks(), 2);
        assert_eq!(lmc.num_components(), 1);
    }

    #[test]
    fn test_lmc_kernel_compute() {
        let mut lmc = LMCKernel::new(2);

        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        lmc.add_component(Box::new(base1), cov1).unwrap();

        let base2 = RbfKernel::new(RbfKernelConfig::new(1.0)).unwrap();
        let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
        lmc.add_component(Box::new(base2), cov2).unwrap();

        let x = TaskInput::new(vec![1.0, 0.0], 0);
        let y = TaskInput::new(vec![1.0, 0.0], 1);

        // Identical features: linear term 1 * 0.5, plus RBF term (1 at zero
        // distance, as the expected value implies) * 1.0, for 1.5 in total.
        let k = lmc.compute_tasks(&x, &y).unwrap();
        assert!((k - 1.5).abs() < 1e-10);
    }

    #[test]
    fn test_lmc_kernel_matrix() {
        let mut lmc = LMCKernel::new(2);

        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        lmc.add_component(Box::new(base), cov).unwrap();

        let inputs = vec![TaskInput::new(vec![1.0], 0), TaskInput::new(vec![1.0], 1)];

        let matrix = lmc.compute_task_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 2);

        assert!((matrix[0][1] - matrix[1][0]).abs() < 1e-10);
    }

    #[test]
    fn test_lmc_kernel_invalid_dimensions() {
        let mut lmc = LMCKernel::new(2);

        let base = LinearKernel::new();
        // 3x3 covariance for a 2-task LMC kernel.
        let cov = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        assert!(lmc.add_component(Box::new(base), cov).is_err());
    }

    // ---------------------------------------------------------------------
    // Kernel-trait wrappers (inputs encoded as [task, features...])
    // ---------------------------------------------------------------------

    #[test]
    fn test_icm_wrapper() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();
        let wrapper = ICMKernelWrapper::new(icm);

        // x: task 0, features [1, 2]; y: task 1, features [3, 4].
        let x = vec![0.0, 1.0, 2.0];
        let y = vec![1.0, 3.0, 4.0];
        let k = wrapper.compute(&x, &y).unwrap();
        // Same numbers as test_icm_kernel_compute.
        assert!((k - 5.5).abs() < 1e-10);

        assert_eq!(wrapper.name(), "ICM");
    }

    #[test]
    fn test_lmc_wrapper() {
        let mut lmc = LMCKernel::new(2);
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        lmc.add_component(Box::new(base), cov).unwrap();

        let wrapper = LMCKernelWrapper::new(lmc);

        // x: task 0, features [1]; y: task 1, features [1].
        let x = vec![0.0, 1.0];
        let y = vec![1.0, 1.0];
        let k = wrapper.compute(&x, &y).unwrap();
        // Base value 1 scaled by B[0][1] = 0.5.
        assert!((k - 0.5).abs() < 1e-10);

        assert_eq!(wrapper.name(), "LMC");
    }

    #[test]
    fn test_wrapper_empty_input() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0]];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();
        let wrapper = ICMKernelWrapper::new(icm);

        // An empty vector has no task index at all.
        assert!(wrapper.compute(&[], &[0.0, 1.0]).is_err());
    }

    // ---------------------------------------------------------------------
    // HadamardTaskKernel
    // ---------------------------------------------------------------------

    #[test]
    fn test_hadamard_task_kernel() {
        let mut hadamard = HadamardTaskKernel::new();

        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm1 = ICMKernel::new(Box::new(base1), cov1).unwrap();
        hadamard.add_kernel(icm1).unwrap();

        let base2 = LinearKernel::new();
        let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
        let icm2 = ICMKernel::new(Box::new(base2), cov2).unwrap();
        hadamard.add_kernel(icm2).unwrap();

        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 1);

        // Product of factor values: (1 * 0.5) * (1 * 1.0) = 0.5.
        let k = hadamard.compute_tasks(&x, &y).unwrap();
        assert!((k - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_hadamard_task_kernel_empty() {
        let hadamard = HadamardTaskKernel::new();
        let x = TaskInput::new(vec![1.0], 0);
        let y = TaskInput::new(vec![1.0], 0);

        assert!(hadamard.compute_tasks(&x, &y).is_err());
    }

    #[test]
    fn test_hadamard_mismatched_tasks() {
        let mut hadamard = HadamardTaskKernel::new();

        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let icm1 = ICMKernel::new(Box::new(base1), cov1).unwrap();
        hadamard.add_kernel(icm1).unwrap();

        let base2 = LinearKernel::new();
        // 3-task factor added after a 2-task factor.
        let cov2 = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];
        let icm2 = ICMKernel::new(Box::new(base2), cov2).unwrap();

        assert!(hadamard.add_kernel(icm2).is_err());
    }

    // ---------------------------------------------------------------------
    // MultiTaskKernelBuilder
    // ---------------------------------------------------------------------

    #[test]
    fn test_builder_icm() {
        let base = LinearKernel::new();
        let cov = vec![vec![1.0, 0.5], vec![0.5, 1.0]];

        let icm = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(base), cov)
            .build_icm()
            .unwrap();

        assert_eq!(icm.num_tasks(), 2);
    }

    #[test]
    fn test_builder_lmc() {
        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];

        let base2 = RbfKernel::new(RbfKernelConfig::new(1.0)).unwrap();
        let cov2 = vec![vec![2.0, 1.0], vec![1.0, 2.0]];

        let lmc = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(base1), cov1)
            .add_component(Box::new(base2), cov2)
            .build_lmc()
            .unwrap();

        assert_eq!(lmc.num_tasks(), 2);
        assert_eq!(lmc.num_components(), 2);
    }

    #[test]
    fn test_builder_icm_wrong_components() {
        let base1 = LinearKernel::new();
        let cov1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];

        let base2 = LinearKernel::new();
        let cov2 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];

        // ICM accepts exactly one component; two must be rejected.
        let result = MultiTaskKernelBuilder::new(2)
            .add_component(Box::new(base1), cov1)
            .add_component(Box::new(base2), cov2)
            .build_icm();

        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // Integration
    // ---------------------------------------------------------------------

    #[test]
    fn test_multitask_with_rbf() {
        let base = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let cov = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let icm = ICMKernel::new(Box::new(base), cov).unwrap();

        // Same point, same task: RBF at zero distance (1, per the expected
        // value) times B[0][0] = 1.
        let x = TaskInput::new(vec![1.0, 2.0], 0);
        let k = icm.compute_tasks(&x, &x).unwrap();
        assert!((k - 1.0).abs() < 1e-10);

        // Same point, tasks 0 and 1: scaled by B[0][1] = 0.8.
        let y = TaskInput::new(vec![1.0, 2.0], 1);
        let k = icm.compute_tasks(&x, &y).unwrap();
        assert!((k - 0.8).abs() < 1e-10);

        // Points at distance 1, same task. The exact value depends on how
        // RbfKernelConfig's parameter is interpreted, so only a range is checked.
        let z = TaskInput::new(vec![1.0, 3.0], 0);
        let k = icm.compute_tasks(&x, &z).unwrap();
        assert!(k > 0.5 && k < 0.7);
    }

    #[test]
    fn test_task_input_creation() {
        let input = TaskInput::new(vec![1.0, 2.0, 3.0], 0);
        assert_eq!(input.features, vec![1.0, 2.0, 3.0]);
        assert_eq!(input.task, 0);

        let input = TaskInput::from_slice(&[4.0, 5.0], 2);
        assert_eq!(input.features, vec![4.0, 5.0]);
        assert_eq!(input.task, 2);
    }
}
1062}