Skip to main content

tensorlogic_sklears_kernels/
spectral_kernel.rs

1// Allow needless_range_loop for spectral computations which are clearer with indexed loops
2#![allow(clippy::needless_range_loop)]
3
4//! Spectral Mixture kernels for discovering latent periodic components.
5//!
6//! The Spectral Mixture (SM) kernel models a signal as a mixture of
7//! Gaussian components in the spectral (frequency) domain. This allows
8//! it to automatically discover multiple periodic patterns in data.
9//!
10//! ## Key Features
11//!
12//! - **Automatic pattern discovery**: Learns multiple periodic components
13//! - **Flexible modeling**: Can approximate any stationary kernel
14//! - **Interpretable**: Each component has clear frequency and length scale
15//!
16//! ## Reference
17//!
18//! Wilson, A. G., & Adams, R. P. (2013). "Gaussian Process Kernels for
19//! Pattern Discovery and Extrapolation." ICML.
20
21use crate::error::{KernelError, Result};
22use crate::types::Kernel;
23use std::f64::consts::PI;
24
25/// A single spectral component with weight, mean frequency, and variance.
26#[derive(Debug, Clone)]
27pub struct SpectralComponent {
28    /// Weight (mixture proportion)
29    pub weight: f64,
30    /// Mean frequency (per-dimension)
31    pub mean: Vec<f64>,
32    /// Variance (per-dimension, controls bandwidth)
33    pub variance: Vec<f64>,
34}
35
36impl SpectralComponent {
37    /// Create a new spectral component.
38    ///
39    /// # Arguments
40    /// * `weight` - Mixture weight (must be positive)
41    /// * `mean` - Mean frequency for each dimension
42    /// * `variance` - Variance for each dimension (must be positive)
43    pub fn new(weight: f64, mean: Vec<f64>, variance: Vec<f64>) -> Result<Self> {
44        if weight <= 0.0 {
45            return Err(KernelError::InvalidParameter {
46                parameter: "weight".to_string(),
47                value: weight.to_string(),
48                reason: "weight must be positive".to_string(),
49            });
50        }
51
52        if mean.len() != variance.len() {
53            return Err(KernelError::InvalidParameter {
54                parameter: "mean/variance".to_string(),
55                value: format!(
56                    "mean.len()={}, variance.len()={}",
57                    mean.len(),
58                    variance.len()
59                ),
60                reason: "mean and variance must have same length".to_string(),
61            });
62        }
63
64        if mean.is_empty() {
65            return Err(KernelError::InvalidParameter {
66                parameter: "mean".to_string(),
67                value: "[]".to_string(),
68                reason: "must have at least one dimension".to_string(),
69            });
70        }
71
72        for (i, &v) in variance.iter().enumerate() {
73            if v <= 0.0 {
74                return Err(KernelError::InvalidParameter {
75                    parameter: format!("variance[{}]", i),
76                    value: v.to_string(),
77                    reason: "variance must be positive".to_string(),
78                });
79            }
80        }
81
82        Ok(Self {
83            weight,
84            mean,
85            variance,
86        })
87    }
88
89    /// Create a 1D spectral component.
90    pub fn new_1d(weight: f64, mean: f64, variance: f64) -> Result<Self> {
91        Self::new(weight, vec![mean], vec![variance])
92    }
93
94    /// Get the number of dimensions.
95    pub fn ndim(&self) -> usize {
96        self.mean.len()
97    }
98}
99
100/// Spectral Mixture (SM) kernel.
101///
102/// K(x, y) = Σ_q w_q * exp(-2π² * Σ_d (x_d - y_d)² * v_q,d) * cos(2π * Σ_d (x_d - y_d) * μ_q,d)
103///
104/// where:
105/// - w_q is the weight of component q
106/// - μ_q,d is the mean frequency of component q in dimension d
107/// - v_q,d is the variance of component q in dimension d
108///
109/// This kernel can approximate any stationary covariance function arbitrarily
110/// well as the number of components increases.
111#[derive(Debug, Clone)]
112pub struct SpectralMixtureKernel {
113    /// Spectral components
114    components: Vec<SpectralComponent>,
115    /// Number of input dimensions
116    ndim: usize,
117}
118
119impl SpectralMixtureKernel {
120    /// Create a new Spectral Mixture kernel.
121    ///
122    /// # Arguments
123    /// * `components` - List of spectral components
124    ///
125    /// # Example
126    /// ```rust
127    /// use tensorlogic_sklears_kernels::spectral_kernel::{SpectralMixtureKernel, SpectralComponent};
128    /// use tensorlogic_sklears_kernels::Kernel;
129    ///
130    /// // Create a kernel with two periodic components
131    /// let components = vec![
132    ///     SpectralComponent::new_1d(1.0, 0.1, 0.01).unwrap(),  // Low frequency
133    ///     SpectralComponent::new_1d(0.5, 1.0, 0.1).unwrap(),   // High frequency
134    /// ];
135    /// let kernel = SpectralMixtureKernel::new(components).unwrap();
136    /// ```
137    pub fn new(components: Vec<SpectralComponent>) -> Result<Self> {
138        if components.is_empty() {
139            return Err(KernelError::InvalidParameter {
140                parameter: "components".to_string(),
141                value: "[]".to_string(),
142                reason: "must have at least one component".to_string(),
143            });
144        }
145
146        let ndim = components[0].ndim();
147        for (i, comp) in components.iter().enumerate() {
148            if comp.ndim() != ndim {
149                return Err(KernelError::InvalidParameter {
150                    parameter: format!("components[{}]", i),
151                    value: format!("ndim={}", comp.ndim()),
152                    reason: format!("all components must have {} dimensions", ndim),
153                });
154            }
155        }
156
157        Ok(Self { components, ndim })
158    }
159
160    /// Create a simple 1D spectral mixture kernel with given frequencies.
161    ///
162    /// # Arguments
163    /// * `frequencies` - List of (weight, mean_frequency, variance) tuples
164    pub fn new_1d(frequencies: Vec<(f64, f64, f64)>) -> Result<Self> {
165        let components: Result<Vec<_>> = frequencies
166            .into_iter()
167            .map(|(w, m, v)| SpectralComponent::new_1d(w, m, v))
168            .collect();
169        Self::new(components?)
170    }
171
172    /// Get the components.
173    pub fn components(&self) -> &[SpectralComponent] {
174        &self.components
175    }
176
177    /// Get the number of components.
178    pub fn num_components(&self) -> usize {
179        self.components.len()
180    }
181
182    /// Get the number of dimensions.
183    pub fn ndim(&self) -> usize {
184        self.ndim
185    }
186
187    /// Compute the contribution of a single component.
188    fn compute_component(&self, comp: &SpectralComponent, tau: &[f64]) -> f64 {
189        let mut exp_term = 0.0;
190        let mut cos_term = 0.0;
191
192        for d in 0..self.ndim {
193            let tau_d = tau[d];
194            exp_term += tau_d * tau_d * comp.variance[d];
195            cos_term += tau_d * comp.mean[d];
196        }
197
198        // K_q = w_q * exp(-2π² * exp_term) * cos(2π * cos_term)
199        comp.weight * (-2.0 * PI * PI * exp_term).exp() * (2.0 * PI * cos_term).cos()
200    }
201}
202
203impl Kernel for SpectralMixtureKernel {
204    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
205        if x.len() != self.ndim {
206            return Err(KernelError::DimensionMismatch {
207                expected: vec![self.ndim],
208                got: vec![x.len()],
209                context: "Spectral Mixture kernel".to_string(),
210            });
211        }
212        if y.len() != self.ndim {
213            return Err(KernelError::DimensionMismatch {
214                expected: vec![self.ndim],
215                got: vec![y.len()],
216                context: "Spectral Mixture kernel".to_string(),
217            });
218        }
219
220        // Compute tau = x - y
221        let tau: Vec<f64> = x.iter().zip(y.iter()).map(|(a, b)| a - b).collect();
222
223        // Sum over all components
224        let mut result = 0.0;
225        for comp in &self.components {
226            result += self.compute_component(comp, &tau);
227        }
228
229        Ok(result)
230    }
231
232    fn name(&self) -> &str {
233        "SpectralMixture"
234    }
235}
236
237/// Exponential Sine Squared kernel (also known as Periodic kernel).
238///
239/// K(x, y) = exp(-2 * sin²(π * |x - y| / period) / l²)
240///
241/// This is equivalent to the ExpSineSquared kernel in scikit-learn.
242/// It models functions that repeat exactly over a specified period.
243#[derive(Debug, Clone)]
244pub struct ExpSineSquaredKernel {
245    /// Period of the periodic pattern
246    period: f64,
247    /// Length scale (controls smoothness within period)
248    length_scale: f64,
249}
250
251impl ExpSineSquaredKernel {
252    /// Create a new Exponential Sine Squared kernel.
253    ///
254    /// # Arguments
255    /// * `period` - Period of the pattern (must be positive)
256    /// * `length_scale` - Length scale parameter (must be positive)
257    pub fn new(period: f64, length_scale: f64) -> Result<Self> {
258        if period <= 0.0 {
259            return Err(KernelError::InvalidParameter {
260                parameter: "period".to_string(),
261                value: period.to_string(),
262                reason: "period must be positive".to_string(),
263            });
264        }
265        if length_scale <= 0.0 {
266            return Err(KernelError::InvalidParameter {
267                parameter: "length_scale".to_string(),
268                value: length_scale.to_string(),
269                reason: "length_scale must be positive".to_string(),
270            });
271        }
272        Ok(Self {
273            period,
274            length_scale,
275        })
276    }
277
278    /// Get the period.
279    pub fn period(&self) -> f64 {
280        self.period
281    }
282
283    /// Get the length scale.
284    pub fn length_scale(&self) -> f64 {
285        self.length_scale
286    }
287}
288
289impl Kernel for ExpSineSquaredKernel {
290    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
291        if x.len() != y.len() {
292            return Err(KernelError::DimensionMismatch {
293                expected: vec![x.len()],
294                got: vec![y.len()],
295                context: "ExpSineSquared kernel".to_string(),
296            });
297        }
298
299        // Compute Euclidean distance
300        let dist: f64 = x
301            .iter()
302            .zip(y.iter())
303            .map(|(a, b)| (a - b) * (a - b))
304            .sum::<f64>()
305            .sqrt();
306
307        let sin_term = (PI * dist / self.period).sin();
308        let result = (-2.0 * sin_term * sin_term / (self.length_scale * self.length_scale)).exp();
309
310        Ok(result)
311    }
312
313    fn name(&self) -> &str {
314        "ExpSineSquared"
315    }
316}
317
318/// Locally Periodic kernel: RBF × Periodic
319///
320/// K(x, y) = k_rbf(x, y) * k_periodic(x, y)
321///
322/// Models functions that are periodic but whose amplitude varies smoothly.
323/// The RBF component controls the locality (how quickly periodicity decays),
324/// while the periodic component captures the repetitive structure.
325#[derive(Debug, Clone)]
326pub struct LocallyPeriodicKernel {
327    /// Period of the periodic component
328    period: f64,
329    /// Length scale for the periodic component
330    periodic_length_scale: f64,
331    /// Length scale for the RBF component (controls locality)
332    rbf_length_scale: f64,
333}
334
335impl LocallyPeriodicKernel {
336    /// Create a new Locally Periodic kernel.
337    ///
338    /// # Arguments
339    /// * `period` - Period of the repetitive pattern
340    /// * `periodic_length_scale` - Length scale within each period
341    /// * `rbf_length_scale` - How quickly the periodic pattern decays
342    pub fn new(period: f64, periodic_length_scale: f64, rbf_length_scale: f64) -> Result<Self> {
343        if period <= 0.0 {
344            return Err(KernelError::InvalidParameter {
345                parameter: "period".to_string(),
346                value: period.to_string(),
347                reason: "period must be positive".to_string(),
348            });
349        }
350        if periodic_length_scale <= 0.0 {
351            return Err(KernelError::InvalidParameter {
352                parameter: "periodic_length_scale".to_string(),
353                value: periodic_length_scale.to_string(),
354                reason: "periodic_length_scale must be positive".to_string(),
355            });
356        }
357        if rbf_length_scale <= 0.0 {
358            return Err(KernelError::InvalidParameter {
359                parameter: "rbf_length_scale".to_string(),
360                value: rbf_length_scale.to_string(),
361                reason: "rbf_length_scale must be positive".to_string(),
362            });
363        }
364        Ok(Self {
365            period,
366            periodic_length_scale,
367            rbf_length_scale,
368        })
369    }
370
371    /// Get the period.
372    pub fn period(&self) -> f64 {
373        self.period
374    }
375
376    /// Get the periodic length scale.
377    pub fn periodic_length_scale(&self) -> f64 {
378        self.periodic_length_scale
379    }
380
381    /// Get the RBF length scale.
382    pub fn rbf_length_scale(&self) -> f64 {
383        self.rbf_length_scale
384    }
385}
386
387impl Kernel for LocallyPeriodicKernel {
388    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
389        if x.len() != y.len() {
390            return Err(KernelError::DimensionMismatch {
391                expected: vec![x.len()],
392                got: vec![y.len()],
393                context: "Locally Periodic kernel".to_string(),
394            });
395        }
396
397        // Compute squared distance and distance
398        let sq_dist: f64 = x.iter().zip(y.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
399        let dist = sq_dist.sqrt();
400
401        // RBF component: exp(-0.5 * d² / l²)
402        let rbf = (-0.5 * sq_dist / (self.rbf_length_scale * self.rbf_length_scale)).exp();
403
404        // Periodic component: exp(-2 * sin²(π * d / p) / l²)
405        let sin_term = (PI * dist / self.period).sin();
406        let periodic = (-2.0 * sin_term * sin_term
407            / (self.periodic_length_scale * self.periodic_length_scale))
408            .exp();
409
410        Ok(rbf * periodic)
411    }
412
413    fn name(&self) -> &str {
414        "LocallyPeriodic"
415    }
416}
417
418/// Product of RBF and Linear kernels.
419///
420/// K(x, y) = k_rbf(x, y) * k_linear(x, y)
421///
422/// Models functions with smoothly varying linear trends.
423#[derive(Debug, Clone)]
424pub struct RbfLinearKernel {
425    /// RBF length scale
426    length_scale: f64,
427    /// Linear kernel variance
428    variance: f64,
429}
430
431impl RbfLinearKernel {
432    /// Create a new RBF × Linear kernel.
433    pub fn new(length_scale: f64, variance: f64) -> Result<Self> {
434        if length_scale <= 0.0 {
435            return Err(KernelError::InvalidParameter {
436                parameter: "length_scale".to_string(),
437                value: length_scale.to_string(),
438                reason: "length_scale must be positive".to_string(),
439            });
440        }
441        if variance <= 0.0 {
442            return Err(KernelError::InvalidParameter {
443                parameter: "variance".to_string(),
444                value: variance.to_string(),
445                reason: "variance must be positive".to_string(),
446            });
447        }
448        Ok(Self {
449            length_scale,
450            variance,
451        })
452    }
453
454    /// Get the length scale.
455    pub fn length_scale(&self) -> f64 {
456        self.length_scale
457    }
458
459    /// Get the variance.
460    pub fn variance(&self) -> f64 {
461        self.variance
462    }
463}
464
465impl Kernel for RbfLinearKernel {
466    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
467        if x.len() != y.len() {
468            return Err(KernelError::DimensionMismatch {
469                expected: vec![x.len()],
470                got: vec![y.len()],
471                context: "RBF-Linear kernel".to_string(),
472            });
473        }
474
475        // Squared distance
476        let sq_dist: f64 = x.iter().zip(y.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
477
478        // RBF component
479        let rbf = (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp();
480
481        // Linear component (dot product)
482        let dot: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
483        let linear = self.variance * dot;
484
485        Ok(rbf * linear)
486    }
487
488    fn name(&self) -> &str {
489        "RBF-Linear"
490    }
491
492    fn is_psd(&self) -> bool {
493        // Product of PSD kernels is PSD
494        true
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the exact-value assertions.
    const TOL: f64 = 1e-10;

    /// True when `a` and `b` differ by less than `TOL`.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    // ===== Spectral Component Tests =====

    #[test]
    fn test_spectral_component_1d() {
        let component = SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap();
        assert!(close(component.weight, 1.0));
        assert_eq!(component.ndim(), 1);
    }

    #[test]
    fn test_spectral_component_multidim() {
        let component = SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01, 0.02]).unwrap();
        assert_eq!(component.ndim(), 2);
    }

    #[test]
    fn test_spectral_component_invalid_weight() {
        // Zero and negative weights are both rejected.
        assert!(SpectralComponent::new_1d(0.0, 0.5, 0.1).is_err());
        assert!(SpectralComponent::new_1d(-1.0, 0.5, 0.1).is_err());
    }

    #[test]
    fn test_spectral_component_invalid_variance() {
        // Zero and negative variances are both rejected.
        assert!(SpectralComponent::new_1d(1.0, 0.5, 0.0).is_err());
        assert!(SpectralComponent::new_1d(1.0, 0.5, -0.1).is_err());
    }

    #[test]
    fn test_spectral_component_mismatched_dims() {
        let outcome = SpectralComponent::new(1.0, vec![0.1, 0.2], vec![0.01]);
        assert!(outcome.is_err());
    }

    // ===== Spectral Mixture Kernel Tests =====

    #[test]
    fn test_spectral_mixture_kernel_single_component() {
        let kernel =
            SpectralMixtureKernel::new(vec![SpectralComponent::new_1d(1.0, 0.0, 0.1).unwrap()])
                .unwrap();
        assert_eq!(kernel.name(), "SpectralMixture");
        assert_eq!(kernel.num_components(), 1);

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).unwrap();
        // At same point with mean=0: cos(0) = 1, exp(0) = 1
        assert!(close(value, 1.0));
    }

    #[test]
    fn test_spectral_mixture_kernel_multiple_components() {
        let low = SpectralComponent::new_1d(0.5, 0.1, 0.01).unwrap();
        let high = SpectralComponent::new_1d(0.5, 1.0, 0.1).unwrap();
        let kernel = SpectralMixtureKernel::new(vec![low, high]).unwrap();
        assert_eq!(kernel.num_components(), 2);

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).unwrap();
        // At same point: should be sum of weights = 1.0
        assert!(close(value, 1.0));
    }

    #[test]
    fn test_spectral_mixture_kernel_1d_convenience() {
        let spec = vec![(1.0, 0.5, 0.1), (0.5, 1.0, 0.05)];
        let kernel = SpectralMixtureKernel::new_1d(spec).unwrap();
        assert_eq!(kernel.num_components(), 2);
        assert_eq!(kernel.ndim(), 1);
    }

    #[test]
    fn test_spectral_mixture_kernel_periodicity() {
        // A single component at frequency 0.25 has period 1 / 0.25 = 4.
        // The SM kernel multiplies a cosine by an exponential decay, so the
        // value at a full period is below 1; a very small variance keeps the
        // decay mild.
        let frequency = 0.25;
        let kernel =
            SpectralMixtureKernel::new(vec![
                SpectralComponent::new_1d(1.0, frequency, 0.0001).unwrap()
            ])
            .unwrap();

        let origin = [0.0];
        let full_period = [4.0]; // cosine term = 1
        let half_period = [2.0]; // cosine term = -1

        let sim_period = kernel.compute(&origin, &full_period).unwrap();
        let sim_half = kernel.compute(&origin, &half_period).unwrap();

        // The full-period value must dominate the half-period value.
        assert!(
            sim_period > sim_half,
            "Period value {} should exceed half-period value {}",
            sim_period,
            sim_half
        );
        // The full-period value stays reasonably high despite the decay.
        assert!(
            sim_period > 0.5,
            "Period value {} should be > 0.5",
            sim_period
        );
    }

    #[test]
    fn test_spectral_mixture_kernel_symmetry() {
        let kernel =
            SpectralMixtureKernel::new(vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()])
                .unwrap();

        let a = [1.0];
        let b = [2.0];

        let forward = kernel.compute(&a, &b).unwrap();
        let backward = kernel.compute(&b, &a).unwrap();
        assert!(close(forward, backward));
    }

    #[test]
    fn test_spectral_mixture_kernel_empty_components() {
        assert!(SpectralMixtureKernel::new(vec![]).is_err());
    }

    #[test]
    fn test_spectral_mixture_kernel_dimension_mismatch() {
        let kernel =
            SpectralMixtureKernel::new(vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()])
                .unwrap();

        let two_dim = [0.0, 0.0];
        let one_dim = [0.0];

        assert!(kernel.compute(&two_dim, &one_dim).is_err());
    }

    // ===== Exponential Sine Squared Kernel Tests =====

    #[test]
    fn test_exp_sine_squared_kernel_basic() {
        let kernel = ExpSineSquaredKernel::new(10.0, 1.0).unwrap();
        assert_eq!(kernel.name(), "ExpSineSquared");

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).unwrap();
        assert!(close(value, 1.0));
    }

    #[test]
    fn test_exp_sine_squared_kernel_periodicity() {
        let period = 10.0;
        let kernel = ExpSineSquaredKernel::new(period, 1.0).unwrap();

        let origin = [0.0];
        let one_period = [period];
        let two_periods = [2.0 * period];

        let sim_one = kernel.compute(&origin, &one_period).unwrap();
        let sim_two = kernel.compute(&origin, &two_periods).unwrap();

        // At exact period multiples, similarity should be very high
        assert!(sim_one > 0.99);
        assert!(sim_two > 0.99);
    }

    #[test]
    fn test_exp_sine_squared_kernel_invalid() {
        assert!(ExpSineSquaredKernel::new(0.0, 1.0).is_err());
        assert!(ExpSineSquaredKernel::new(10.0, 0.0).is_err());
    }

    // ===== Locally Periodic Kernel Tests =====

    #[test]
    fn test_locally_periodic_kernel_basic() {
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 100.0).unwrap();
        assert_eq!(kernel.name(), "LocallyPeriodic");

        let origin = [0.0];
        let value = kernel.compute(&origin, &origin).unwrap();
        assert!(close(value, 1.0));
    }

    #[test]
    fn test_locally_periodic_kernel_decay() {
        // With small RBF length scale, periodicity should decay quickly
        let kernel = LocallyPeriodicKernel::new(10.0, 1.0, 5.0).unwrap();

        let origin = [0.0];
        let near = [10.0]; // one period away
        let far = [100.0]; // ten periods away

        let sim_near = kernel.compute(&origin, &near).unwrap();
        let sim_far = kernel.compute(&origin, &far).unwrap();

        // Far point should have much lower similarity due to RBF decay
        assert!(sim_near > sim_far);
    }

    #[test]
    fn test_locally_periodic_kernel_invalid() {
        // Each of the three parameters is rejected when zero.
        assert!(LocallyPeriodicKernel::new(0.0, 1.0, 1.0).is_err());
        assert!(LocallyPeriodicKernel::new(10.0, 0.0, 1.0).is_err());
        assert!(LocallyPeriodicKernel::new(10.0, 1.0, 0.0).is_err());
    }

    // ===== RBF-Linear Kernel Tests =====

    #[test]
    fn test_rbf_linear_kernel_basic() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();
        assert_eq!(kernel.name(), "RBF-Linear");
        assert!(kernel.is_psd());

        let point = [1.0, 2.0];

        let value = kernel.compute(&point, &point).unwrap();
        // dot(x,x) = 5, rbf(x,x) = 1, so result = 5
        assert!(close(value, 5.0));
    }

    #[test]
    fn test_rbf_linear_kernel_symmetry() {
        let kernel = RbfLinearKernel::new(1.0, 1.0).unwrap();

        let a = [1.0, 2.0];
        let b = [3.0, 4.0];

        let forward = kernel.compute(&a, &b).unwrap();
        let backward = kernel.compute(&b, &a).unwrap();
        assert!(close(forward, backward));
    }

    #[test]
    fn test_rbf_linear_kernel_invalid() {
        assert!(RbfLinearKernel::new(0.0, 1.0).is_err());
        assert!(RbfLinearKernel::new(1.0, 0.0).is_err());
    }

    // ===== Integration Tests =====

    #[test]
    fn test_spectral_kernels_symmetry() {
        let spectral =
            SpectralMixtureKernel::new(vec![SpectralComponent::new_1d(1.0, 0.5, 0.1).unwrap()])
                .unwrap();
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(spectral),
            Box::new(ExpSineSquaredKernel::new(10.0, 1.0).unwrap()),
            Box::new(LocallyPeriodicKernel::new(10.0, 1.0, 10.0).unwrap()),
            Box::new(RbfLinearKernel::new(1.0, 1.0).unwrap()),
        ];

        let a = [1.0];
        let b = [2.0];

        for kernel in kernels.iter() {
            let forward = kernel.compute(&a, &b).unwrap();
            let backward = kernel.compute(&b, &a).unwrap();
            assert!(
                (forward - backward).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }
}
769}