1use crate::error::{KernelError, Result};
30use crate::types::Kernel;
31
/// Automatic Relevance Determination (ARD) RBF (squared-exponential) kernel:
///
/// k(x, y) = variance * exp(-0.5 * Σ_i (x_i - y_i)² / ℓ_i²)
///
/// Each input dimension has its own length scale ℓ_i, so the kernel can
/// express per-dimension relevance (a large ℓ_i makes dimension i matter less).
#[derive(Debug, Clone)]
pub struct ArdRbfKernel {
    /// Per-dimension length scales; validated non-empty and strictly positive.
    length_scales: Vec<f64>,
    /// Signal variance; validated strictly positive. Equals k(x, x).
    variance: f64,
}
45
46impl ArdRbfKernel {
47 pub fn new(length_scales: Vec<f64>) -> Result<Self> {
59 Self::with_variance(length_scales, 1.0)
60 }
61
62 pub fn with_variance(length_scales: Vec<f64>, variance: f64) -> Result<Self> {
68 if length_scales.is_empty() {
69 return Err(KernelError::InvalidParameter {
70 parameter: "length_scales".to_string(),
71 value: "[]".to_string(),
72 reason: "must have at least one dimension".to_string(),
73 });
74 }
75
76 for (i, &ls) in length_scales.iter().enumerate() {
77 if ls <= 0.0 {
78 return Err(KernelError::InvalidParameter {
79 parameter: format!("length_scales[{}]", i),
80 value: ls.to_string(),
81 reason: "all length scales must be positive".to_string(),
82 });
83 }
84 }
85
86 if variance <= 0.0 {
87 return Err(KernelError::InvalidParameter {
88 parameter: "variance".to_string(),
89 value: variance.to_string(),
90 reason: "variance must be positive".to_string(),
91 });
92 }
93
94 Ok(Self {
95 length_scales,
96 variance,
97 })
98 }
99
100 pub fn length_scales(&self) -> &[f64] {
102 &self.length_scales
103 }
104
105 pub fn variance(&self) -> f64 {
107 self.variance
108 }
109
110 pub fn ndim(&self) -> usize {
112 self.length_scales.len()
113 }
114
115 pub fn compute_gradient(&self, x: &[f64], y: &[f64]) -> Result<KernelGradient> {
123 if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
124 return Err(KernelError::DimensionMismatch {
125 expected: vec![self.length_scales.len()],
126 got: vec![x.len(), y.len()],
127 context: "ARD RBF kernel gradient".to_string(),
128 });
129 }
130
131 let mut sum_scaled_sq = 0.0;
133 let mut scaled_sq_diffs = Vec::with_capacity(self.length_scales.len());
134
135 for i in 0..self.length_scales.len() {
136 let diff = x[i] - y[i];
137 let ls = self.length_scales[i];
138 let scaled_sq = diff * diff / (ls * ls);
139 scaled_sq_diffs.push(scaled_sq);
140 sum_scaled_sq += scaled_sq;
141 }
142
143 let exp_term = (-0.5 * sum_scaled_sq).exp();
144 let k_value = self.variance * exp_term;
145
146 let grad_length_scales: Vec<f64> = scaled_sq_diffs
148 .iter()
149 .enumerate()
150 .map(|(i, &sq_diff)| {
151 let ls = self.length_scales[i];
152 k_value * sq_diff / ls
153 })
154 .collect();
155
156 let grad_variance = exp_term;
158
159 Ok(KernelGradient {
160 value: k_value,
161 grad_length_scales,
162 grad_variance,
163 })
164 }
165}
166
167impl Kernel for ArdRbfKernel {
168 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
169 if x.len() != self.length_scales.len() {
170 return Err(KernelError::DimensionMismatch {
171 expected: vec![self.length_scales.len()],
172 got: vec![x.len()],
173 context: "ARD RBF kernel".to_string(),
174 });
175 }
176 if y.len() != self.length_scales.len() {
177 return Err(KernelError::DimensionMismatch {
178 expected: vec![self.length_scales.len()],
179 got: vec![y.len()],
180 context: "ARD RBF kernel".to_string(),
181 });
182 }
183
184 let mut sum_scaled_sq = 0.0;
185 for i in 0..self.length_scales.len() {
186 let diff = x[i] - y[i];
187 let ls = self.length_scales[i];
188 sum_scaled_sq += (diff * diff) / (ls * ls);
189 }
190
191 Ok(self.variance * (-0.5 * sum_scaled_sq).exp())
192 }
193
194 fn name(&self) -> &str {
195 "ARD-RBF"
196 }
197}
198
/// ARD Matérn kernel with per-dimension length scales.
///
/// Supports the half-integer smoothness orders ν ∈ {0.5, 1.5, 2.5}, for
/// which the Matérn function has a closed form (see `compute`).
#[derive(Debug, Clone)]
pub struct ArdMaternKernel {
    /// Per-dimension length scales; validated non-empty and strictly positive.
    length_scales: Vec<f64>,
    /// Signal variance; validated strictly positive. Equals k(x, x).
    variance: f64,
    /// Smoothness parameter; restricted to 0.5, 1.5, or 2.5 at construction.
    nu: f64,
}
211
212impl ArdMaternKernel {
213 pub fn new(length_scales: Vec<f64>, nu: f64) -> Result<Self> {
219 Self::with_variance(length_scales, nu, 1.0)
220 }
221
222 pub fn with_variance(length_scales: Vec<f64>, nu: f64, variance: f64) -> Result<Self> {
224 if length_scales.is_empty() {
225 return Err(KernelError::InvalidParameter {
226 parameter: "length_scales".to_string(),
227 value: "[]".to_string(),
228 reason: "must have at least one dimension".to_string(),
229 });
230 }
231
232 for (i, &ls) in length_scales.iter().enumerate() {
233 if ls <= 0.0 {
234 return Err(KernelError::InvalidParameter {
235 parameter: format!("length_scales[{}]", i),
236 value: ls.to_string(),
237 reason: "all length scales must be positive".to_string(),
238 });
239 }
240 }
241
242 if !((nu - 0.5).abs() < 1e-10 || (nu - 1.5).abs() < 1e-10 || (nu - 2.5).abs() < 1e-10) {
243 return Err(KernelError::InvalidParameter {
244 parameter: "nu".to_string(),
245 value: nu.to_string(),
246 reason: "nu must be 0.5, 1.5, or 2.5".to_string(),
247 });
248 }
249
250 if variance <= 0.0 {
251 return Err(KernelError::InvalidParameter {
252 parameter: "variance".to_string(),
253 value: variance.to_string(),
254 reason: "variance must be positive".to_string(),
255 });
256 }
257
258 Ok(Self {
259 length_scales,
260 variance,
261 nu,
262 })
263 }
264
265 pub fn exponential(length_scales: Vec<f64>) -> Result<Self> {
267 Self::new(length_scales, 0.5)
268 }
269
270 pub fn nu_3_2(length_scales: Vec<f64>) -> Result<Self> {
272 Self::new(length_scales, 1.5)
273 }
274
275 pub fn nu_5_2(length_scales: Vec<f64>) -> Result<Self> {
277 Self::new(length_scales, 2.5)
278 }
279
280 pub fn length_scales(&self) -> &[f64] {
282 &self.length_scales
283 }
284
285 pub fn variance(&self) -> f64 {
287 self.variance
288 }
289
290 pub fn nu(&self) -> f64 {
292 self.nu
293 }
294
295 fn scaled_distance(&self, x: &[f64], y: &[f64]) -> f64 {
297 let mut sum = 0.0;
298 for i in 0..self.length_scales.len() {
299 let diff = x[i] - y[i];
300 let ls = self.length_scales[i];
301 sum += (diff * diff) / (ls * ls);
302 }
303 sum.sqrt()
304 }
305}
306
307impl Kernel for ArdMaternKernel {
308 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
309 if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
310 return Err(KernelError::DimensionMismatch {
311 expected: vec![self.length_scales.len()],
312 got: vec![x.len(), y.len()],
313 context: "ARD Matérn kernel".to_string(),
314 });
315 }
316
317 let r = self.scaled_distance(x, y);
318
319 if r < 1e-10 {
321 return Ok(self.variance);
322 }
323
324 let sqrt_2nu = (2.0 * self.nu).sqrt();
325 let scaled_r = sqrt_2nu * r;
326
327 let result = if (self.nu - 0.5).abs() < 1e-10 {
328 (-scaled_r).exp()
330 } else if (self.nu - 1.5).abs() < 1e-10 {
331 (1.0 + scaled_r) * (-scaled_r).exp()
333 } else {
334 (1.0 + scaled_r + scaled_r * scaled_r / 3.0) * (-scaled_r).exp()
336 };
337
338 Ok(self.variance * result)
339 }
340
341 fn name(&self) -> &str {
342 "ARD-Matérn"
343 }
344}
345
/// ARD Rational Quadratic kernel:
///
/// k(x, y) = variance * (1 + r² / (2α))^(−α), r² = Σ_i (x_i - y_i)² / ℓ_i²
///
/// Equivalent to a scale mixture of RBF kernels; α controls the relative
/// weighting of large- versus small-scale variation.
#[derive(Debug, Clone)]
pub struct ArdRationalQuadraticKernel {
    /// Per-dimension length scales; validated non-empty and strictly positive.
    length_scales: Vec<f64>,
    /// Signal variance; validated strictly positive. Equals k(x, x).
    variance: f64,
    /// Scale-mixture parameter α; validated strictly positive.
    alpha: f64,
}
358
359impl ArdRationalQuadraticKernel {
360 pub fn new(length_scales: Vec<f64>, alpha: f64) -> Result<Self> {
362 Self::with_variance(length_scales, alpha, 1.0)
363 }
364
365 pub fn with_variance(length_scales: Vec<f64>, alpha: f64, variance: f64) -> Result<Self> {
367 if length_scales.is_empty() {
368 return Err(KernelError::InvalidParameter {
369 parameter: "length_scales".to_string(),
370 value: "[]".to_string(),
371 reason: "must have at least one dimension".to_string(),
372 });
373 }
374
375 for (i, &ls) in length_scales.iter().enumerate() {
376 if ls <= 0.0 {
377 return Err(KernelError::InvalidParameter {
378 parameter: format!("length_scales[{}]", i),
379 value: ls.to_string(),
380 reason: "all length scales must be positive".to_string(),
381 });
382 }
383 }
384
385 if alpha <= 0.0 {
386 return Err(KernelError::InvalidParameter {
387 parameter: "alpha".to_string(),
388 value: alpha.to_string(),
389 reason: "alpha must be positive".to_string(),
390 });
391 }
392
393 if variance <= 0.0 {
394 return Err(KernelError::InvalidParameter {
395 parameter: "variance".to_string(),
396 value: variance.to_string(),
397 reason: "variance must be positive".to_string(),
398 });
399 }
400
401 Ok(Self {
402 length_scales,
403 variance,
404 alpha,
405 })
406 }
407
408 pub fn length_scales(&self) -> &[f64] {
410 &self.length_scales
411 }
412
413 pub fn variance(&self) -> f64 {
415 self.variance
416 }
417
418 pub fn alpha(&self) -> f64 {
420 self.alpha
421 }
422}
423
424impl Kernel for ArdRationalQuadraticKernel {
425 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
426 if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
427 return Err(KernelError::DimensionMismatch {
428 expected: vec![self.length_scales.len()],
429 got: vec![x.len(), y.len()],
430 context: "ARD Rational Quadratic kernel".to_string(),
431 });
432 }
433
434 let mut sum_scaled_sq = 0.0;
435 for i in 0..self.length_scales.len() {
436 let diff = x[i] - y[i];
437 let ls = self.length_scales[i];
438 sum_scaled_sq += (diff * diff) / (ls * ls);
439 }
440
441 let term = 1.0 + sum_scaled_sq / (2.0 * self.alpha);
442 Ok(self.variance * term.powf(-self.alpha))
443 }
444
445 fn name(&self) -> &str {
446 "ARD-RationalQuadratic"
447 }
448}
449
/// A kernel value bundled with its partial derivatives, as returned by
/// [`ArdRbfKernel::compute_gradient`].
#[derive(Debug, Clone)]
pub struct KernelGradient {
    /// The kernel value k(x, y) itself.
    pub value: f64,
    /// Partial derivative of k with respect to each length scale.
    pub grad_length_scales: Vec<f64>,
    /// Partial derivative of k with respect to the signal variance.
    pub grad_variance: f64,
}
460
/// White-noise kernel: returns `variance` when the two inputs coincide
/// (element-wise, within a small tolerance) and 0 otherwise. Commonly added
/// to another kernel to model observation noise on the diagonal.
#[derive(Debug, Clone)]
pub struct WhiteNoiseKernel {
    /// Noise variance; validated strictly positive.
    variance: f64,
}
471
472impl WhiteNoiseKernel {
473 pub fn new(variance: f64) -> Result<Self> {
478 if variance <= 0.0 {
479 return Err(KernelError::InvalidParameter {
480 parameter: "variance".to_string(),
481 value: variance.to_string(),
482 reason: "variance must be positive".to_string(),
483 });
484 }
485 Ok(Self { variance })
486 }
487
488 pub fn variance(&self) -> f64 {
490 self.variance
491 }
492}
493
494impl Kernel for WhiteNoiseKernel {
495 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
496 if x.len() != y.len() {
497 return Err(KernelError::DimensionMismatch {
498 expected: vec![x.len()],
499 got: vec![y.len()],
500 context: "White Noise kernel".to_string(),
501 });
502 }
503
504 let is_same = x.iter().zip(y.iter()).all(|(a, b)| (a - b).abs() < 1e-10);
506
507 if is_same {
508 Ok(self.variance)
509 } else {
510 Ok(0.0)
511 }
512 }
513
514 fn name(&self) -> &str {
515 "WhiteNoise"
516 }
517}
518
/// Constant kernel: k(x, y) = variance for every pair of equal-length inputs.
/// Useful as a bias term when combined with other kernels.
#[derive(Debug, Clone)]
pub struct ConstantKernel {
    /// The constant value returned by `compute`; validated strictly positive.
    variance: f64,
}
527
528impl ConstantKernel {
529 pub fn new(variance: f64) -> Result<Self> {
531 if variance <= 0.0 {
532 return Err(KernelError::InvalidParameter {
533 parameter: "variance".to_string(),
534 value: variance.to_string(),
535 reason: "variance must be positive".to_string(),
536 });
537 }
538 Ok(Self { variance })
539 }
540
541 pub fn variance(&self) -> f64 {
543 self.variance
544 }
545}
546
547impl Kernel for ConstantKernel {
548 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
549 if x.len() != y.len() {
550 return Err(KernelError::DimensionMismatch {
551 expected: vec![x.len()],
552 got: vec![y.len()],
553 context: "Constant kernel".to_string(),
554 });
555 }
556 Ok(self.variance)
557 }
558
559 fn name(&self) -> &str {
560 "Constant"
561 }
562}
563
/// (Inhomogeneous) dot-product / linear kernel:
///
/// k(x, y) = variance_bias + variance * ⟨x, y⟩
///
/// With `variance_bias = 0` this is the plain linear kernel.
#[derive(Debug, Clone)]
pub struct DotProductKernel {
    /// Multiplier on the dot product; validated non-negative.
    variance: f64,
    /// Additive bias term; validated non-negative.
    variance_bias: f64,
}
577
578impl DotProductKernel {
579 pub fn new(variance: f64, variance_bias: f64) -> Result<Self> {
581 if variance < 0.0 {
582 return Err(KernelError::InvalidParameter {
583 parameter: "variance".to_string(),
584 value: variance.to_string(),
585 reason: "variance must be non-negative".to_string(),
586 });
587 }
588 if variance_bias < 0.0 {
589 return Err(KernelError::InvalidParameter {
590 parameter: "variance_bias".to_string(),
591 value: variance_bias.to_string(),
592 reason: "variance_bias must be non-negative".to_string(),
593 });
594 }
595 Ok(Self {
596 variance,
597 variance_bias,
598 })
599 }
600
601 pub fn simple() -> Self {
603 Self {
604 variance: 1.0,
605 variance_bias: 0.0,
606 }
607 }
608
609 pub fn variance(&self) -> f64 {
611 self.variance
612 }
613
614 pub fn variance_bias(&self) -> f64 {
616 self.variance_bias
617 }
618}
619
620impl Kernel for DotProductKernel {
621 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
622 if x.len() != y.len() {
623 return Err(KernelError::DimensionMismatch {
624 expected: vec![x.len()],
625 got: vec![y.len()],
626 context: "Dot Product kernel".to_string(),
627 });
628 }
629
630 let dot: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
631 Ok(self.variance_bias + self.variance * dot)
632 }
633
634 fn name(&self) -> &str {
635 "DotProduct"
636 }
637}
638
/// Wrapper that multiplies every value of an inner kernel by a fixed
/// positive factor: k'(x, y) = variance * k(x, y).
///
/// The `K: Kernel` bound lives on the impl blocks rather than on the struct
/// definition (Rust API guideline C-STRUCT-BOUNDS): the bound is only
/// needed where the inner kernel is actually used, and dropping it here is
/// backward-compatible for all existing callers.
#[derive(Debug, Clone)]
pub struct ScaledKernel<K> {
    /// Inner kernel whose values are scaled.
    kernel: K,
    /// Positive multiplier applied to every kernel value.
    variance: f64,
}
651
652impl<K: Kernel> ScaledKernel<K> {
653 pub fn new(kernel: K, variance: f64) -> Result<Self> {
655 if variance <= 0.0 {
656 return Err(KernelError::InvalidParameter {
657 parameter: "variance".to_string(),
658 value: variance.to_string(),
659 reason: "variance must be positive".to_string(),
660 });
661 }
662 Ok(Self { kernel, variance })
663 }
664
665 pub fn kernel(&self) -> &K {
667 &self.kernel
668 }
669
670 pub fn variance(&self) -> f64 {
672 self.variance
673 }
674}
675
676impl<K: Kernel> Kernel for ScaledKernel<K> {
677 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
678 let base_value = self.kernel.compute(x, y)?;
679 Ok(self.variance * base_value)
680 }
681
682 fn name(&self) -> &str {
683 "Scaled"
684 }
685
686 fn is_psd(&self) -> bool {
687 self.kernel.is_psd()
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- ARD RBF kernel ----------

    // k(x, x) equals the (unit) variance for identical inputs.
    #[test]
    fn test_ard_rbf_kernel_basic() {
        let kernel = ArdRbfKernel::new(vec![1.0, 1.0, 1.0]).unwrap();
        assert_eq!(kernel.name(), "ARD-RBF");
        assert_eq!(kernel.ndim(), 3);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.0];

        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // A larger length scale makes the same displacement "cost" less, so a
    // unit step along the ls=10 axis yields higher similarity than along ls=1.
    #[test]
    fn test_ard_rbf_kernel_different_length_scales() {
        let kernel = ArdRbfKernel::new(vec![10.0, 1.0, 1.0]).unwrap();

        let x = vec![0.0, 0.0, 0.0];
        let y1 = vec![1.0, 0.0, 0.0];
        let y2 = vec![0.0, 1.0, 0.0];
        let sim1 = kernel.compute(&x, &y1).unwrap();
        let sim2 = kernel.compute(&x, &y2).unwrap();

        assert!(sim1 > sim2);
    }

    // With an explicit variance, k(x, x) equals that variance.
    #[test]
    fn test_ard_rbf_kernel_with_variance() {
        let kernel = ArdRbfKernel::with_variance(vec![1.0, 1.0], 2.0).unwrap();
        assert!((kernel.variance() - 2.0).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 2.0).abs() < 1e-10);
    }

    // Gradient value must agree with compute(), and one gradient entry is
    // produced per length scale.
    #[test]
    fn test_ard_rbf_kernel_gradient() {
        let kernel = ArdRbfKernel::new(vec![1.0, 2.0]).unwrap();
        let x = vec![0.0, 0.0];
        let y = vec![1.0, 1.0];

        let grad = kernel.compute_gradient(&x, &y).unwrap();

        let value = kernel.compute(&x, &y).unwrap();
        assert!((grad.value - value).abs() < 1e-10);

        assert_eq!(grad.grad_length_scales.len(), 2);
    }

    // Empty length-scale vectors are rejected.
    #[test]
    fn test_ard_rbf_kernel_invalid_empty() {
        let result = ArdRbfKernel::new(vec![]);
        assert!(result.is_err());
    }

    // Any non-positive length scale is rejected.
    #[test]
    fn test_ard_rbf_kernel_invalid_negative() {
        let result = ArdRbfKernel::new(vec![1.0, -1.0, 1.0]);
        assert!(result.is_err());
    }

    // Zero variance is rejected (must be strictly positive).
    #[test]
    fn test_ard_rbf_kernel_invalid_variance() {
        let result = ArdRbfKernel::with_variance(vec![1.0], 0.0);
        assert!(result.is_err());
    }

    // Inputs whose length differs from the kernel dimensionality error out.
    #[test]
    fn test_ard_rbf_kernel_dimension_mismatch() {
        let kernel = ArdRbfKernel::new(vec![1.0, 1.0]).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0];
        assert!(kernel.compute(&x, &y).is_err());
    }

    // k(x, y) == k(y, x).
    #[test]
    fn test_ard_rbf_kernel_symmetry() {
        let kernel = ArdRbfKernel::new(vec![1.0, 2.0, 0.5]).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let k_xy = kernel.compute(&x, &y).unwrap();
        let k_yx = kernel.compute(&y, &x).unwrap();
        assert!((k_xy - k_yx).abs() < 1e-10);
    }

    // ---------- ARD Matérn kernel ----------

    // ν = 3/2 constructor; k(x, x) = 1 for unit variance.
    #[test]
    fn test_ard_matern_kernel_nu_3_2() {
        let kernel = ArdMaternKernel::nu_3_2(vec![1.0, 1.0]).unwrap();
        assert_eq!(kernel.name(), "ARD-Matérn");
        assert!((kernel.nu() - 1.5).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // ν = 5/2: distinct points give a similarity strictly between 0 and 1.
    #[test]
    fn test_ard_matern_kernel_nu_5_2() {
        let kernel = ArdMaternKernel::nu_5_2(vec![1.0, 2.0]).unwrap();
        assert!((kernel.nu() - 2.5).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let y = vec![0.5, 0.5];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!(sim > 0.0 && sim < 1.0);
    }

    // `exponential` is the ν = 0.5 special case.
    #[test]
    fn test_ard_matern_kernel_exponential() {
        let kernel = ArdMaternKernel::exponential(vec![1.0]).unwrap();
        assert!((kernel.nu() - 0.5).abs() < 1e-10);
    }

    // ν values other than 0.5 / 1.5 / 2.5 are rejected.
    #[test]
    fn test_ard_matern_kernel_invalid_nu() {
        let result = ArdMaternKernel::new(vec![1.0], 1.0);
        assert!(result.is_err());
    }

    // Same ARD behavior as the RBF: larger length scale => higher similarity
    // for the same displacement.
    #[test]
    fn test_ard_matern_kernel_different_length_scales() {
        let kernel = ArdMaternKernel::nu_3_2(vec![10.0, 1.0]).unwrap();

        let x = vec![0.0, 0.0];
        let y1 = vec![1.0, 0.0];
        let y2 = vec![0.0, 1.0];

        let sim1 = kernel.compute(&x, &y1).unwrap();
        let sim2 = kernel.compute(&x, &y2).unwrap();

        assert!(sim1 > sim2);
    }

    // ---------- ARD Rational Quadratic kernel ----------

    // k(x, x) = 1 for unit variance.
    #[test]
    fn test_ard_rq_kernel_basic() {
        let kernel = ArdRationalQuadraticKernel::new(vec![1.0, 1.0], 2.0).unwrap();
        assert_eq!(kernel.name(), "ARD-RationalQuadratic");

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // k(x, x) equals the explicit variance.
    #[test]
    fn test_ard_rq_kernel_with_variance() {
        let kernel = ArdRationalQuadraticKernel::with_variance(vec![1.0], 2.0, 3.0).unwrap();
        assert!((kernel.variance() - 3.0).abs() < 1e-10);

        let x = vec![0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 3.0).abs() < 1e-10);
    }

    // ---------- White-noise kernel ----------

    // Identical points yield the noise variance.
    #[test]
    fn test_white_noise_kernel_same_point() {
        let kernel = WhiteNoiseKernel::new(0.1).unwrap();
        assert_eq!(kernel.name(), "WhiteNoise");

        let x = vec![1.0, 2.0, 3.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 0.1).abs() < 1e-10);
    }

    // Points differing beyond the tolerance yield zero.
    #[test]
    fn test_white_noise_kernel_different_points() {
        let kernel = WhiteNoiseKernel::new(0.1).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.1];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!(sim.abs() < 1e-10);
    }

    // Zero or negative variance is rejected.
    #[test]
    fn test_white_noise_kernel_invalid() {
        let result = WhiteNoiseKernel::new(0.0);
        assert!(result.is_err());

        let result = WhiteNoiseKernel::new(-1.0);
        assert!(result.is_err());
    }

    // ---------- Constant kernel ----------

    // Returns its variance regardless of the input values.
    #[test]
    fn test_constant_kernel() {
        let kernel = ConstantKernel::new(2.5).unwrap();
        assert_eq!(kernel.name(), "Constant");

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 2.5).abs() < 1e-10);
    }

    // Zero or negative variance is rejected.
    #[test]
    fn test_constant_kernel_invalid() {
        assert!(ConstantKernel::new(0.0).is_err());
        assert!(ConstantKernel::new(-1.0).is_err());
    }

    // ---------- Dot-product kernel ----------

    // Plain dot product: 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product_kernel_simple() {
        let kernel = DotProductKernel::simple();
        assert_eq!(kernel.name(), "DotProduct");

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 32.0).abs() < 1e-10);
    }

    // Orthogonal inputs: only the bias term (5.0) remains.
    #[test]
    fn test_dot_product_kernel_with_bias() {
        let kernel = DotProductKernel::new(1.0, 5.0).unwrap();

        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 5.0).abs() < 1e-10);
    }

    // variance scales the dot product: 2 * (1*3 + 2*4) = 22.
    #[test]
    fn test_dot_product_kernel_with_variance() {
        let kernel = DotProductKernel::new(2.0, 0.0).unwrap();

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 22.0).abs() < 1e-10);
    }

    // ---------- Scaled kernel ----------

    // Assumes LinearKernel computes the plain dot product (32 here) —
    // scaled by 2 that gives 64. TODO(review): confirm against
    // crate::tensor_kernels::LinearKernel.
    #[test]
    fn test_scaled_kernel() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let scaled = ScaledKernel::new(base, 2.0).unwrap();
        assert_eq!(scaled.name(), "Scaled");

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let sim = scaled.compute(&x, &y).unwrap();
        assert!((sim - 64.0).abs() < 1e-10);
    }

    // Zero scaling variance is rejected.
    #[test]
    fn test_scaled_kernel_invalid() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let result = ScaledKernel::new(base, 0.0);
        assert!(result.is_err());
    }

    // PSD-ness is delegated to the wrapped kernel.
    #[test]
    fn test_scaled_kernel_psd() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let scaled = ScaledKernel::new(base, 2.0).unwrap();
        assert!(scaled.is_psd());
    }

    // ---------- Cross-cutting properties ----------

    // Every ARD kernel must be symmetric in its arguments.
    #[test]
    fn test_ard_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(ArdRbfKernel::new(vec![1.0, 2.0]).unwrap()),
            Box::new(ArdMaternKernel::nu_3_2(vec![1.0, 2.0]).unwrap()),
            Box::new(ArdRationalQuadraticKernel::new(vec![1.0, 2.0], 2.0).unwrap()),
        ];

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        for kernel in kernels {
            let k_xy = kernel.compute(&x, &y).unwrap();
            let k_yx = kernel.compute(&y, &x).unwrap();
            assert!(
                (k_xy - k_yx).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }

    // The utility kernels must also be symmetric.
    #[test]
    fn test_utility_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(WhiteNoiseKernel::new(0.1).unwrap()),
            Box::new(ConstantKernel::new(1.0).unwrap()),
            Box::new(DotProductKernel::simple()),
        ];

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        for kernel in kernels {
            let k_xy = kernel.compute(&x, &y).unwrap();
            let k_yx = kernel.compute(&y, &x).unwrap();
            assert!(
                (k_xy - k_yx).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }
}