1#![allow(clippy::needless_range_loop)]
3
4use crate::error::{KernelError, Result};
22use crate::types::Kernel;
23use std::f64::consts::PI;
24
/// One Gaussian component of a spectral-mixture spectral density.
///
/// Inside [`SpectralMixtureKernel`] a component contributes
/// `weight * exp(-2π² Σ_d τ_d² variance_d) * cos(2π Σ_d τ_d mean_d)` at lag `τ = x - y`.
#[derive(Debug, Clone)]
pub struct SpectralComponent {
    /// Mixture weight; checked to be positive when built through `SpectralComponent::new`.
    pub weight: f64,
    /// Mean frequency, one entry per input dimension.
    pub mean: Vec<f64>,
    /// Spectral variance, one entry per input dimension (same length as `mean`).
    pub variance: Vec<f64>,
}
35
36impl SpectralComponent {
37 pub fn new(weight: f64, mean: Vec<f64>, variance: Vec<f64>) -> Result<Self> {
44 if weight <= 0.0 {
45 return Err(KernelError::InvalidParameter {
46 parameter: "weight".to_string(),
47 value: weight.to_string(),
48 reason: "weight must be positive".to_string(),
49 });
50 }
51
52 if mean.len() != variance.len() {
53 return Err(KernelError::InvalidParameter {
54 parameter: "mean/variance".to_string(),
55 value: format!(
56 "mean.len()={}, variance.len()={}",
57 mean.len(),
58 variance.len()
59 ),
60 reason: "mean and variance must have same length".to_string(),
61 });
62 }
63
64 if mean.is_empty() {
65 return Err(KernelError::InvalidParameter {
66 parameter: "mean".to_string(),
67 value: "[]".to_string(),
68 reason: "must have at least one dimension".to_string(),
69 });
70 }
71
72 for (i, &v) in variance.iter().enumerate() {
73 if v <= 0.0 {
74 return Err(KernelError::InvalidParameter {
75 parameter: format!("variance[{}]", i),
76 value: v.to_string(),
77 reason: "variance must be positive".to_string(),
78 });
79 }
80 }
81
82 Ok(Self {
83 weight,
84 mean,
85 variance,
86 })
87 }
88
89 pub fn new_1d(weight: f64, mean: f64, variance: f64) -> Result<Self> {
91 Self::new(weight, vec![mean], vec![variance])
92 }
93
94 pub fn ndim(&self) -> usize {
96 self.mean.len()
97 }
98}
99
100#[derive(Debug, Clone)]
112pub struct SpectralMixtureKernel {
113 components: Vec<SpectralComponent>,
115 ndim: usize,
117}
118
119impl SpectralMixtureKernel {
120 pub fn new(components: Vec<SpectralComponent>) -> Result<Self> {
138 if components.is_empty() {
139 return Err(KernelError::InvalidParameter {
140 parameter: "components".to_string(),
141 value: "[]".to_string(),
142 reason: "must have at least one component".to_string(),
143 });
144 }
145
146 let ndim = components[0].ndim();
147 for (i, comp) in components.iter().enumerate() {
148 if comp.ndim() != ndim {
149 return Err(KernelError::InvalidParameter {
150 parameter: format!("components[{}]", i),
151 value: format!("ndim={}", comp.ndim()),
152 reason: format!("all components must have {} dimensions", ndim),
153 });
154 }
155 }
156
157 Ok(Self { components, ndim })
158 }
159
160 pub fn new_1d(frequencies: Vec<(f64, f64, f64)>) -> Result<Self> {
165 let components: Result<Vec<_>> = frequencies
166 .into_iter()
167 .map(|(w, m, v)| SpectralComponent::new_1d(w, m, v))
168 .collect();
169 Self::new(components?)
170 }
171
172 pub fn components(&self) -> &[SpectralComponent] {
174 &self.components
175 }
176
177 pub fn num_components(&self) -> usize {
179 self.components.len()
180 }
181
182 pub fn ndim(&self) -> usize {
184 self.ndim
185 }
186
187 fn compute_component(&self, comp: &SpectralComponent, tau: &[f64]) -> f64 {
189 let mut exp_term = 0.0;
190 let mut cos_term = 0.0;
191
192 for d in 0..self.ndim {
193 let tau_d = tau[d];
194 exp_term += tau_d * tau_d * comp.variance[d];
195 cos_term += tau_d * comp.mean[d];
196 }
197
198 comp.weight * (-2.0 * PI * PI * exp_term).exp() * (2.0 * PI * cos_term).cos()
200 }
201}
202
203impl Kernel for SpectralMixtureKernel {
204 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
205 if x.len() != self.ndim {
206 return Err(KernelError::DimensionMismatch {
207 expected: vec![self.ndim],
208 got: vec![x.len()],
209 context: "Spectral Mixture kernel".to_string(),
210 });
211 }
212 if y.len() != self.ndim {
213 return Err(KernelError::DimensionMismatch {
214 expected: vec![self.ndim],
215 got: vec![y.len()],
216 context: "Spectral Mixture kernel".to_string(),
217 });
218 }
219
220 let tau: Vec<f64> = x.iter().zip(y.iter()).map(|(a, b)| a - b).collect();
222
223 let mut result = 0.0;
225 for comp in &self.components {
226 result += self.compute_component(comp, &tau);
227 }
228
229 Ok(result)
230 }
231
232 fn name(&self) -> &str {
233 "SpectralMixture"
234 }
235}
236
/// Periodic ("exp-sine-squared") kernel of the Euclidean distance `d = ||x - y||`:
/// `k(x, y) = exp(-2 sin²(π d / period) / length_scale²)`.
#[derive(Debug, Clone)]
pub struct ExpSineSquaredKernel {
    /// Distance after which the kernel value repeats.
    period: f64,
    /// Scales how sharply similarity drops within one period.
    length_scale: f64,
}
250
251impl ExpSineSquaredKernel {
252 pub fn new(period: f64, length_scale: f64) -> Result<Self> {
258 if period <= 0.0 {
259 return Err(KernelError::InvalidParameter {
260 parameter: "period".to_string(),
261 value: period.to_string(),
262 reason: "period must be positive".to_string(),
263 });
264 }
265 if length_scale <= 0.0 {
266 return Err(KernelError::InvalidParameter {
267 parameter: "length_scale".to_string(),
268 value: length_scale.to_string(),
269 reason: "length_scale must be positive".to_string(),
270 });
271 }
272 Ok(Self {
273 period,
274 length_scale,
275 })
276 }
277
278 pub fn period(&self) -> f64 {
280 self.period
281 }
282
283 pub fn length_scale(&self) -> f64 {
285 self.length_scale
286 }
287}
288
289impl Kernel for ExpSineSquaredKernel {
290 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
291 if x.len() != y.len() {
292 return Err(KernelError::DimensionMismatch {
293 expected: vec![x.len()],
294 got: vec![y.len()],
295 context: "ExpSineSquared kernel".to_string(),
296 });
297 }
298
299 let dist: f64 = x
301 .iter()
302 .zip(y.iter())
303 .map(|(a, b)| (a - b) * (a - b))
304 .sum::<f64>()
305 .sqrt();
306
307 let sin_term = (PI * dist / self.period).sin();
308 let result = (-2.0 * sin_term * sin_term / (self.length_scale * self.length_scale)).exp();
309
310 Ok(result)
311 }
312
313 fn name(&self) -> &str {
314 "ExpSineSquared"
315 }
316}
317
/// Locally periodic kernel: the product of an RBF envelope and a periodic
/// (exp-sine-squared) factor, both functions of the Euclidean distance `||x - y||`.
#[derive(Debug, Clone)]
pub struct LocallyPeriodicKernel {
    /// Period of the periodic factor.
    period: f64,
    /// Length scale of the periodic factor.
    periodic_length_scale: f64,
    /// Length scale of the RBF envelope.
    rbf_length_scale: f64,
}
334
335impl LocallyPeriodicKernel {
336 pub fn new(period: f64, periodic_length_scale: f64, rbf_length_scale: f64) -> Result<Self> {
343 if period <= 0.0 {
344 return Err(KernelError::InvalidParameter {
345 parameter: "period".to_string(),
346 value: period.to_string(),
347 reason: "period must be positive".to_string(),
348 });
349 }
350 if periodic_length_scale <= 0.0 {
351 return Err(KernelError::InvalidParameter {
352 parameter: "periodic_length_scale".to_string(),
353 value: periodic_length_scale.to_string(),
354 reason: "periodic_length_scale must be positive".to_string(),
355 });
356 }
357 if rbf_length_scale <= 0.0 {
358 return Err(KernelError::InvalidParameter {
359 parameter: "rbf_length_scale".to_string(),
360 value: rbf_length_scale.to_string(),
361 reason: "rbf_length_scale must be positive".to_string(),
362 });
363 }
364 Ok(Self {
365 period,
366 periodic_length_scale,
367 rbf_length_scale,
368 })
369 }
370
371 pub fn period(&self) -> f64 {
373 self.period
374 }
375
376 pub fn periodic_length_scale(&self) -> f64 {
378 self.periodic_length_scale
379 }
380
381 pub fn rbf_length_scale(&self) -> f64 {
383 self.rbf_length_scale
384 }
385}
386
387impl Kernel for LocallyPeriodicKernel {
388 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
389 if x.len() != y.len() {
390 return Err(KernelError::DimensionMismatch {
391 expected: vec![x.len()],
392 got: vec![y.len()],
393 context: "Locally Periodic kernel".to_string(),
394 });
395 }
396
397 let sq_dist: f64 = x.iter().zip(y.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
399 let dist = sq_dist.sqrt();
400
401 let rbf = (-0.5 * sq_dist / (self.rbf_length_scale * self.rbf_length_scale)).exp();
403
404 let sin_term = (PI * dist / self.period).sin();
406 let periodic = (-2.0 * sin_term * sin_term
407 / (self.periodic_length_scale * self.periodic_length_scale))
408 .exp();
409
410 Ok(rbf * periodic)
411 }
412
413 fn name(&self) -> &str {
414 "LocallyPeriodic"
415 }
416}
417
/// Product of an RBF kernel and a scaled linear kernel:
/// `k(x, y) = exp(-||x - y||² / (2 length_scale²)) * variance * (x · y)`.
#[derive(Debug, Clone)]
pub struct RbfLinearKernel {
    /// Length scale of the RBF factor.
    length_scale: f64,
    /// Multiplier applied to the dot product in the linear factor.
    variance: f64,
}
430
431impl RbfLinearKernel {
432 pub fn new(length_scale: f64, variance: f64) -> Result<Self> {
434 if length_scale <= 0.0 {
435 return Err(KernelError::InvalidParameter {
436 parameter: "length_scale".to_string(),
437 value: length_scale.to_string(),
438 reason: "length_scale must be positive".to_string(),
439 });
440 }
441 if variance <= 0.0 {
442 return Err(KernelError::InvalidParameter {
443 parameter: "variance".to_string(),
444 value: variance.to_string(),
445 reason: "variance must be positive".to_string(),
446 });
447 }
448 Ok(Self {
449 length_scale,
450 variance,
451 })
452 }
453
454 pub fn length_scale(&self) -> f64 {
456 self.length_scale
457 }
458
459 pub fn variance(&self) -> f64 {
461 self.variance
462 }
463}
464
465impl Kernel for RbfLinearKernel {
466 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
467 if x.len() != y.len() {
468 return Err(KernelError::DimensionMismatch {
469 expected: vec![x.len()],
470 got: vec![y.len()],
471 context: "RBF-Linear kernel".to_string(),
472 });
473 }
474
475 let sq_dist: f64 = x.iter().zip(y.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
477
478 let rbf = (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp();
480
481 let dot: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
483 let linear = self.variance * dot;
484
485 Ok(rbf * linear)
486 }
487
488 fn name(&self) -> &str {
489 "RBF-Linear"
490 }
491
492 fn is_psd(&self) -> bool {
493 true
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    // ----- SpectralComponent -----

    #[test]
    fn test_spectral_component_1d() {
        let comp = SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap();
        assert!((comp.weight - 1.0).abs() < 1e-10);
        assert_eq!(comp.ndim(), 1);
    }

    #[test]
    fn test_spectral_component_multidim() {
        let comp = SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01, 0.02]).unwrap();
        assert_eq!(comp.ndim(), 2);
    }

    #[test]
    fn test_spectral_component_invalid_weight() {
        assert!(SpectralComponent::new_1d(0.0, 0.5, 0.1).is_err());
        assert!(SpectralComponent::new_1d(-1.0, 0.5, 0.1).is_err());
    }

    #[test]
    fn test_spectral_component_invalid_variance() {
        assert!(SpectralComponent::new_1d(1.0, 0.5, 0.0).is_err());
        assert!(SpectralComponent::new_1d(1.0, 0.5, -0.1).is_err());
    }

    #[test]
    fn test_spectral_component_mismatched_dims() {
        assert!(SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01]).is_err());
    }

    #[test]
    fn test_spectral_component_empty_dims() {
        assert!(SpectralComponent::new(1.0, vec![], vec![]).is_err());
    }

    // ----- SpectralMixtureKernel -----

    #[test]
    fn test_spectral_mixture_kernel_single_component() {
        let components = vec![SpectralComponent::new_1d(1.0, 0.0, 0.1).unwrap()];
        let kernel = SpectralMixtureKernel::new(components).unwrap();
        assert_eq!(kernel.name(), "SpectralMixture");
        assert_eq!(kernel.num_components(), 1);

        // At zero lag each component contributes exactly its weight.
        let x = vec![0.0];
        let y = vec![0.0];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spectral_mixture_kernel_multiple_components() {
        let components = vec![
            SpectralComponent::new_1d(0.5, 0.1, 0.01).unwrap(),
            SpectralComponent::new_1d(0.5, 1.0, 0.1).unwrap(),
        ];
        let kernel = SpectralMixtureKernel::new(components).unwrap();
        assert_eq!(kernel.num_components(), 2);

        // At zero lag the value is the sum of the weights: 0.5 + 0.5.
        let x = vec![0.0];
        let y = vec![0.0];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spectral_mixture_kernel_1d_convenience() {
        let kernel =
            SpectralMixtureKernel::new_1d(vec![(1.0, 0.5, 0.1), (0.5, 1.0, 0.05)]).unwrap();
        assert_eq!(kernel.num_components(), 2);
        assert_eq!(kernel.ndim(), 1);
    }

    #[test]
    fn test_spectral_mixture_kernel_mismatched_component_dims() {
        let components = vec![
            SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap(),
            SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01, 0.02]).unwrap(),
        ];
        assert!(SpectralMixtureKernel::new(components).is_err());
    }

    #[test]
    fn test_spectral_mixture_kernel_periodicity() {
        // Mean frequency 0.25 gives period 1 / 0.25 = 4. The tiny variance keeps
        // the Gaussian envelope close to 1 over a few periods.
        let freq = 0.25;
        let components = vec![SpectralComponent::new_1d(1.0, freq, 0.0001).unwrap()];
        let kernel = SpectralMixtureKernel::new(components).unwrap();

        let x = vec![0.0];
        let y_period = vec![4.0]; // one full period away
        let y_half = vec![2.0]; // half a period away
        let sim_period = kernel.compute(&x, &y_period).unwrap();
        let sim_half = kernel.compute(&x, &y_half).unwrap();

        assert!(
            sim_period > sim_half,
            "Period value {} should exceed half-period value {}",
            sim_period,
            sim_half
        );
        assert!(
            sim_period > 0.5,
            "Period value {} should be > 0.5",
            sim_period
        );
    }

    #[test]
    fn test_spectral_mixture_kernel_symmetry() {
        let components = vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()];
        let kernel = SpectralMixtureKernel::new(components).unwrap();

        let x = vec![1.0];
        let y = vec![2.0];

        let k_xy = kernel.compute(&x, &y).unwrap();
        let k_yx = kernel.compute(&y, &x).unwrap();
        assert!((k_xy - k_yx).abs() < 1e-10);
    }

    #[test]
    fn test_spectral_mixture_kernel_empty_components() {
        let result = SpectralMixtureKernel::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_spectral_mixture_kernel_dimension_mismatch() {
        let components = vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()];
        let kernel = SpectralMixtureKernel::new(components).unwrap();

        // The kernel is 1-D, so a 2-D `x` must be rejected.
        let x = vec![0.0, 0.0];
        let y = vec![0.0];
        assert!(kernel.compute(&x, &y).is_err());
    }

    // ----- ExpSineSquaredKernel -----

    #[test]
    fn test_exp_sine_squared_kernel_basic() {
        let kernel = ExpSineSquaredKernel::new(10.0, 1.0).unwrap();
        assert_eq!(kernel.name(), "ExpSineSquared");

        let x = vec![0.0];
        let y = vec![0.0];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_exp_sine_squared_kernel_periodicity() {
        let period = 10.0;
        let kernel = ExpSineSquaredKernel::new(period, 1.0).unwrap();

        let x = vec![0.0];
        let y1 = vec![period]; // one period away
        let y2 = vec![2.0 * period]; // two periods away
        let sim1 = kernel.compute(&x, &y1).unwrap();
        let sim2 = kernel.compute(&x, &y2).unwrap();

        // Whole-period offsets should be (numerically) indistinguishable from zero lag.
        assert!(sim1 > 0.99);
        assert!(sim2 > 0.99);
    }

    #[test]
    fn test_exp_sine_squared_kernel_half_period() {
        // Half a period away sin² = 1, so k = exp(-2 / length_scale²) = exp(-2).
        let kernel = ExpSineSquaredKernel::new(10.0, 1.0).unwrap();
        let sim = kernel.compute(&[0.0], &[5.0]).unwrap();
        assert!((sim - (-2.0f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_exp_sine_squared_kernel_dimension_mismatch() {
        let kernel = ExpSineSquaredKernel::new(10.0, 1.0).unwrap();
        assert!(kernel.compute(&[0.0], &[0.0, 1.0]).is_err());
    }

    #[test]
    fn test_exp_sine_squared_kernel_invalid() {
        assert!(ExpSineSquaredKernel::new(0.0, 1.0).is_err());
        assert!(ExpSineSquaredKernel::new(10.0, 0.0).is_err());
    }

    // ----- LocallyPeriodicKernel -----

    #[test]
    fn test_locally_periodic_kernel_basic() {
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 100.0).unwrap();
        assert_eq!(kernel.name(), "LocallyPeriodic");

        let x = vec![0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_locally_periodic_kernel_decay() {
        // Both points are whole periods away, so only the RBF envelope differs.
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 5.0).unwrap();

        let x = vec![0.0];
        let y_near = vec![10.0]; // one period away
        let y_far = vec![100.0]; // ten periods away
        let sim_near = kernel.compute(&x, &y_near).unwrap();
        let sim_far = kernel.compute(&x, &y_far).unwrap();

        assert!(sim_near > sim_far);
    }

    #[test]
    fn test_locally_periodic_kernel_dimension_mismatch() {
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 5.0).unwrap();
        assert!(kernel.compute(&[0.0, 0.0], &[0.0]).is_err());
    }

    #[test]
    fn test_locally_periodic_kernel_invalid() {
        assert!(LocallyPeriodicKernel::new(0.0, 1.0, 1.0).is_err());
        assert!(LocallyPeriodicKernel::new(10.0, 0.0, 1.0).is_err());
        assert!(LocallyPeriodicKernel::new(10.0, 1.0, 0.0).is_err());
    }

    // ----- RbfLinearKernel -----

    #[test]
    fn test_rbf_linear_kernel_basic() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();
        assert_eq!(kernel.name(), "RBF-Linear");
        assert!(kernel.is_psd());

        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        // Identical inputs: the RBF factor is 1 and x · y = 1 + 4 = 5.
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_linear_kernel_value() {
        // ||x - y||² = 8 gives RBF exp(-4); x · y = 3 + 8 = 11.
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();
        let sim = kernel.compute(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
        assert!((sim - 11.0 * (-4.0f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_linear_kernel_orthogonal_inputs() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();
        let sim = kernel.compute(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
        assert!(sim.abs() < 1e-12);
    }

    #[test]
    fn test_rbf_linear_kernel_symmetry() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        let k_xy = kernel.compute(&x, &y).unwrap();
        let k_yx = kernel.compute(&y, &x).unwrap();
        assert!((k_xy - k_yx).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_linear_kernel_dimension_mismatch() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();
        assert!(kernel.compute(&[1.0], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_rbf_linear_kernel_invalid() {
        assert!(RbfLinearKernel::new(0.0, 1.0).is_err());
        assert!(RbfLinearKernel::new(1.0, 0.0).is_err());
    }

    // ----- Cross-kernel properties -----

    #[test]
    fn test_spectral_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(
                SpectralMixtureKernel::new(vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()])
                    .unwrap(),
            ),
            Box::new(ExpSineSquaredKernel::new(10.0, 1.0).unwrap()),
            Box::new(LocallyPeriodicKernel::new(10.0, 1.0, 10.0).unwrap()),
            Box::new(RbfLinearKernel::new(1.0, 1.0).unwrap()),
        ];

        let x = vec![1.0];
        let y = vec![2.0];

        for kernel in kernels {
            let k_xy = kernel.compute(&x, &y).unwrap();
            let k_yx = kernel.compute(&y, &x).unwrap();
            assert!(
                (k_xy - k_yx).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }
}