
tensorlogic_sklears_kernels/ard_kernel.rs

//! Automatic Relevance Determination (ARD) kernels.
//!
//! ARD kernels learn a separate length scale for each input dimension,
//! automatically determining the relevance of each feature. This is
//! particularly useful for high-dimensional data where some features
//! may be more important than others.
//!
//! ## Key Features
//!
//! - **Per-dimension length scales**: Each feature has its own length scale
//! - **Feature selection**: Irrelevant features get large length scales (effectively ignored)
//! - **Gradient support**: For hyperparameter optimization via gradient descent
//!
//! ## Example
//!
//! ```rust
//! use tensorlogic_sklears_kernels::ard_kernel::ArdRbfKernel;
//! use tensorlogic_sklears_kernels::Kernel;
//!
//! // Create an ARD RBF kernel with 3 features, each with its own length scale
//! let length_scales = vec![1.0, 2.0, 0.5]; // Different relevance per dimension
//! let kernel = ArdRbfKernel::new(length_scales.clone()).unwrap();
//!
//! let x = vec![1.0, 2.0, 3.0];
//! let y = vec![1.5, 2.5, 3.5];
//! let sim = kernel.compute(&x, &y).unwrap();
//! assert!(sim > 0.0 && sim < 1.0);
//! ```
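//!
//! Because a dimension's influence scales inversely with its length scale, a
//! unit difference along a dimension with a large length scale barely lowers
//! the kernel value. A minimal sketch of that effect, using the same API as above:
//!
//! ```rust
//! use tensorlogic_sklears_kernels::ard_kernel::ArdRbfKernel;
//! use tensorlogic_sklears_kernels::Kernel;
//!
//! let kernel = ArdRbfKernel::new(vec![10.0, 1.0]).unwrap();
//! let x = vec![0.0, 0.0];
//! let far_in_dim0 = kernel.compute(&x, &[1.0, 0.0]).unwrap(); // "irrelevant" dim
//! let far_in_dim1 = kernel.compute(&x, &[0.0, 1.0]).unwrap(); // "relevant" dim
//! assert!(far_in_dim0 > far_in_dim1);
//! ```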

use crate::error::{KernelError, Result};
use crate::types::Kernel;

/// ARD (Automatic Relevance Determination) RBF kernel.
///
/// K(x, y) = σ² * exp(-0.5 * Σ((x_i - y_i)² / l_i²))
///
/// Each dimension has its own length scale `l_i`, allowing the kernel
/// to automatically weight features by their relevance.
#[derive(Debug, Clone)]
pub struct ArdRbfKernel {
    /// Per-dimension length scales
    length_scales: Vec<f64>,
    /// Signal variance (output scale)
    variance: f64,
}

impl ArdRbfKernel {
    /// Create a new ARD RBF kernel with per-dimension length scales.
    ///
    /// # Arguments
    /// * `length_scales` - Length scale for each dimension (all must be positive)
    ///
    /// # Example
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::ArdRbfKernel;
    ///
    /// let kernel = ArdRbfKernel::new(vec![1.0, 2.0, 0.5]).unwrap();
    /// ```
    pub fn new(length_scales: Vec<f64>) -> Result<Self> {
        Self::with_variance(length_scales, 1.0)
    }

    /// Create ARD RBF kernel with custom signal variance.
    ///
    /// # Arguments
    /// * `length_scales` - Per-dimension length scales
    /// * `variance` - Signal variance (output scale, must be positive)
    pub fn with_variance(length_scales: Vec<f64>, variance: f64) -> Result<Self> {
        if length_scales.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "length_scales".to_string(),
                value: "[]".to_string(),
                reason: "must have at least one dimension".to_string(),
            });
        }

        for (i, &ls) in length_scales.iter().enumerate() {
            if ls <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: format!("length_scales[{}]", i),
                    value: ls.to_string(),
                    reason: "all length scales must be positive".to_string(),
                });
            }
        }

        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }

        Ok(Self {
            length_scales,
            variance,
        })
    }

    /// Get the length scales.
    pub fn length_scales(&self) -> &[f64] {
        &self.length_scales
    }

    /// Get the signal variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }

    /// Get the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.length_scales.len()
    }

    /// Compute the kernel value and its gradient with respect to the hyperparameters.
    ///
    /// Returns gradients for:
    /// 1. Each length scale (one per dimension)
    /// 2. The signal variance
    ///
    /// This is useful for hyperparameter optimization via gradient descent.
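    ///
    /// # Example
    ///
    /// A minimal sketch of reading the value and per-parameter gradients off
    /// one pair of points:
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::ArdRbfKernel;
    ///
    /// let kernel = ArdRbfKernel::new(vec![1.0, 2.0]).unwrap();
    /// let grad = kernel.compute_gradient(&[0.0, 0.0], &[1.0, 1.0]).unwrap();
    /// // One gradient entry per length scale, plus one for the variance.
    /// assert_eq!(grad.grad_length_scales.len(), 2);
    /// assert!(grad.grad_variance > 0.0);
    /// ```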
    pub fn compute_gradient(&self, x: &[f64], y: &[f64]) -> Result<KernelGradient> {
        if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.length_scales.len()],
                got: vec![x.len(), y.len()],
                context: "ARD RBF kernel gradient".to_string(),
            });
        }

        // Compute scaled squared differences
        let mut sum_scaled_sq = 0.0;
        let mut scaled_sq_diffs = Vec::with_capacity(self.length_scales.len());

        for i in 0..self.length_scales.len() {
            let diff = x[i] - y[i];
            let ls = self.length_scales[i];
            let scaled_sq = diff * diff / (ls * ls);
            scaled_sq_diffs.push(scaled_sq);
            sum_scaled_sq += scaled_sq;
        }

        let exp_term = (-0.5 * sum_scaled_sq).exp();
        let k_value = self.variance * exp_term;

        // Gradient w.r.t. each length scale: dk/dl_i = k * (x_i - y_i)² / l_i³
        let grad_length_scales: Vec<f64> = scaled_sq_diffs
            .iter()
            .enumerate()
            .map(|(i, &sq_diff)| {
                let ls = self.length_scales[i];
                k_value * sq_diff / ls
            })
            .collect();

        // Gradient w.r.t. variance: dk/dσ² = exp_term
        let grad_variance = exp_term;

        Ok(KernelGradient {
            value: k_value,
            grad_length_scales,
            grad_variance,
        })
    }
}

impl Kernel for ArdRbfKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != self.length_scales.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.length_scales.len()],
                got: vec![x.len()],
                context: "ARD RBF kernel".to_string(),
            });
        }
        if y.len() != self.length_scales.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.length_scales.len()],
                got: vec![y.len()],
                context: "ARD RBF kernel".to_string(),
            });
        }

        let mut sum_scaled_sq = 0.0;
        for i in 0..self.length_scales.len() {
            let diff = x[i] - y[i];
            let ls = self.length_scales[i];
            sum_scaled_sq += (diff * diff) / (ls * ls);
        }

        Ok(self.variance * (-0.5 * sum_scaled_sq).exp())
    }

    fn name(&self) -> &str {
        "ARD-RBF"
    }
}

/// ARD Matérn kernel with per-dimension length scales.
///
/// Supports nu = 0.5 (exponential), 1.5, and 2.5. With the ARD-scaled
/// distance r = sqrt(Σ((x_i - y_i)² / l_i²)), the closed forms are:
///
/// - nu = 0.5: K(x, y) = σ² * exp(-r)
/// - nu = 1.5: K(x, y) = σ² * (1 + √3·r) * exp(-√3·r)
/// - nu = 2.5: K(x, y) = σ² * (1 + √5·r + 5r²/3) * exp(-√5·r)
#[derive(Debug, Clone)]
pub struct ArdMaternKernel {
    /// Per-dimension length scales
    length_scales: Vec<f64>,
    /// Signal variance (output scale)
    variance: f64,
    /// Smoothness parameter (0.5, 1.5, or 2.5)
    nu: f64,
}

impl ArdMaternKernel {
    /// Create a new ARD Matérn kernel.
    ///
    /// # Arguments
    /// * `length_scales` - Per-dimension length scales
    /// * `nu` - Smoothness parameter (must be 0.5, 1.5, or 2.5)
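    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::ArdMaternKernel;
    /// use tensorlogic_sklears_kernels::Kernel;
    ///
    /// let kernel = ArdMaternKernel::new(vec![1.0, 2.0], 1.5).unwrap();
    /// let sim = kernel.compute(&[0.0, 0.0], &[0.5, 0.5]).unwrap();
    /// assert!(sim > 0.0 && sim < 1.0);
    /// ```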
    pub fn new(length_scales: Vec<f64>, nu: f64) -> Result<Self> {
        Self::with_variance(length_scales, nu, 1.0)
    }

    /// Create ARD Matérn kernel with custom variance.
    pub fn with_variance(length_scales: Vec<f64>, nu: f64, variance: f64) -> Result<Self> {
        if length_scales.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "length_scales".to_string(),
                value: "[]".to_string(),
                reason: "must have at least one dimension".to_string(),
            });
        }

        for (i, &ls) in length_scales.iter().enumerate() {
            if ls <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: format!("length_scales[{}]", i),
                    value: ls.to_string(),
                    reason: "all length scales must be positive".to_string(),
                });
            }
        }

        if !((nu - 0.5).abs() < 1e-10 || (nu - 1.5).abs() < 1e-10 || (nu - 2.5).abs() < 1e-10) {
            return Err(KernelError::InvalidParameter {
                parameter: "nu".to_string(),
                value: nu.to_string(),
                reason: "nu must be 0.5, 1.5, or 2.5".to_string(),
            });
        }

        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }

        Ok(Self {
            length_scales,
            variance,
            nu,
        })
    }

    /// Create ARD Matérn 1/2 kernel (exponential).
    pub fn exponential(length_scales: Vec<f64>) -> Result<Self> {
        Self::new(length_scales, 0.5)
    }

    /// Create ARD Matérn 3/2 kernel.
    pub fn nu_3_2(length_scales: Vec<f64>) -> Result<Self> {
        Self::new(length_scales, 1.5)
    }

    /// Create ARD Matérn 5/2 kernel.
    pub fn nu_5_2(length_scales: Vec<f64>) -> Result<Self> {
        Self::new(length_scales, 2.5)
    }

    /// Get the length scales.
    pub fn length_scales(&self) -> &[f64] {
        &self.length_scales
    }

    /// Get the signal variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }

    /// Get the smoothness parameter nu.
    pub fn nu(&self) -> f64 {
        self.nu
    }

    /// Compute scaled Euclidean distance using ARD length scales.
    fn scaled_distance(&self, x: &[f64], y: &[f64]) -> f64 {
        let mut sum = 0.0;
        for i in 0..self.length_scales.len() {
            let diff = x[i] - y[i];
            let ls = self.length_scales[i];
            sum += (diff * diff) / (ls * ls);
        }
        sum.sqrt()
    }
}

impl Kernel for ArdMaternKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.length_scales.len()],
                got: vec![x.len(), y.len()],
                context: "ARD Matérn kernel".to_string(),
            });
        }

        let r = self.scaled_distance(x, y);

        // Handle same point
        if r < 1e-10 {
            return Ok(self.variance);
        }

        let sqrt_2nu = (2.0 * self.nu).sqrt();
        let scaled_r = sqrt_2nu * r;

        let result = if (self.nu - 0.5).abs() < 1e-10 {
            // Matérn 1/2 (exponential)
            (-scaled_r).exp()
        } else if (self.nu - 1.5).abs() < 1e-10 {
            // Matérn 3/2
            (1.0 + scaled_r) * (-scaled_r).exp()
        } else {
            // Matérn 5/2
            (1.0 + scaled_r + scaled_r * scaled_r / 3.0) * (-scaled_r).exp()
        };

        Ok(self.variance * result)
    }

    fn name(&self) -> &str {
        "ARD-Matérn"
    }
}

/// ARD Rational Quadratic kernel.
///
/// K(x, y) = σ² * (1 + Σ((x_i - y_i)² / (2 * α * l_i²)))^(-α)
#[derive(Debug, Clone)]
pub struct ArdRationalQuadraticKernel {
    /// Per-dimension length scales
    length_scales: Vec<f64>,
    /// Signal variance
    variance: f64,
    /// Scale mixture parameter
    alpha: f64,
}

impl ArdRationalQuadraticKernel {
    /// Create a new ARD Rational Quadratic kernel.
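    ///
    /// # Arguments
    /// * `length_scales` - Per-dimension length scales (all must be positive)
    /// * `alpha` - Scale mixture parameter (must be positive)
    ///
    /// # Example
    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::ArdRationalQuadraticKernel;
    /// use tensorlogic_sklears_kernels::Kernel;
    ///
    /// let kernel = ArdRationalQuadraticKernel::new(vec![1.0, 2.0], 2.0).unwrap();
    /// let sim = kernel.compute(&[0.0, 0.0], &[1.0, 1.0]).unwrap();
    /// assert!(sim > 0.0 && sim < 1.0);
    /// ```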
    pub fn new(length_scales: Vec<f64>, alpha: f64) -> Result<Self> {
        Self::with_variance(length_scales, alpha, 1.0)
    }

    /// Create with custom variance.
    pub fn with_variance(length_scales: Vec<f64>, alpha: f64, variance: f64) -> Result<Self> {
        if length_scales.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "length_scales".to_string(),
                value: "[]".to_string(),
                reason: "must have at least one dimension".to_string(),
            });
        }

        for (i, &ls) in length_scales.iter().enumerate() {
            if ls <= 0.0 {
                return Err(KernelError::InvalidParameter {
                    parameter: format!("length_scales[{}]", i),
                    value: ls.to_string(),
                    reason: "all length scales must be positive".to_string(),
                });
            }
        }

        if alpha <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "alpha".to_string(),
                value: alpha.to_string(),
                reason: "alpha must be positive".to_string(),
            });
        }

        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }

        Ok(Self {
            length_scales,
            variance,
            alpha,
        })
    }

    /// Get the length scales.
    pub fn length_scales(&self) -> &[f64] {
        &self.length_scales
    }

    /// Get the variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }

    /// Get the alpha parameter.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
}

impl Kernel for ArdRationalQuadraticKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != self.length_scales.len() || y.len() != self.length_scales.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.length_scales.len()],
                got: vec![x.len(), y.len()],
                context: "ARD Rational Quadratic kernel".to_string(),
            });
        }

        let mut sum_scaled_sq = 0.0;
        for i in 0..self.length_scales.len() {
            let diff = x[i] - y[i];
            let ls = self.length_scales[i];
            sum_scaled_sq += (diff * diff) / (ls * ls);
        }

        let term = 1.0 + sum_scaled_sq / (2.0 * self.alpha);
        Ok(self.variance * term.powf(-self.alpha))
    }

    fn name(&self) -> &str {
        "ARD-RationalQuadratic"
    }
}

/// Gradient information for kernel hyperparameter optimization.
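///
/// A minimal sketch of one illustrative gradient-ascent step on the length
/// scales (the learning rate and single-pair objective here are hypothetical
/// choices for illustration, not part of the API):
///
/// ```rust
/// use tensorlogic_sklears_kernels::ard_kernel::ArdRbfKernel;
///
/// let kernel = ArdRbfKernel::new(vec![1.0, 2.0]).unwrap();
/// let grad = kernel.compute_gradient(&[0.0, 0.0], &[1.0, 1.0]).unwrap();
///
/// let lr = 0.01; // illustrative learning rate
/// let stepped: Vec<f64> = kernel
///     .length_scales()
///     .iter()
///     .zip(&grad.grad_length_scales)
///     .map(|(ls, g)| ls + lr * g)
///     .collect();
/// let kernel = ArdRbfKernel::new(stepped).unwrap();
/// assert_eq!(kernel.ndim(), 2);
/// ```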
#[derive(Debug, Clone)]
pub struct KernelGradient {
    /// The kernel value K(x, y)
    pub value: f64,
    /// Gradient with respect to each length scale
    pub grad_length_scales: Vec<f64>,
    /// Gradient with respect to the signal variance
    pub grad_variance: f64,
}

/// Utility kernel: White Noise kernel for observation noise modeling.
///
/// K(x, y) = σ² if x == y, else 0
///
/// Used to model i.i.d. observation noise in Gaussian Processes.
#[derive(Debug, Clone)]
pub struct WhiteNoiseKernel {
    /// Noise variance
    variance: f64,
}

impl WhiteNoiseKernel {
    /// Create a new white noise kernel.
    ///
    /// # Arguments
    /// * `variance` - Noise variance (must be positive)
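    ///
    /// # Example
    ///
    /// A minimal sketch: identical inputs return the noise variance,
    /// distinct inputs return zero.
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::WhiteNoiseKernel;
    /// use tensorlogic_sklears_kernels::Kernel;
    ///
    /// let kernel = WhiteNoiseKernel::new(0.1).unwrap();
    /// assert!((kernel.compute(&[1.0, 2.0], &[1.0, 2.0]).unwrap() - 0.1).abs() < 1e-12);
    /// assert_eq!(kernel.compute(&[1.0, 2.0], &[1.0, 2.5]).unwrap(), 0.0);
    /// ```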
    pub fn new(variance: f64) -> Result<Self> {
        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }
        Ok(Self { variance })
    }

    /// Get the noise variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }
}

impl Kernel for WhiteNoiseKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != y.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![x.len()],
                got: vec![y.len()],
                context: "White Noise kernel".to_string(),
            });
        }

        // Check if x and y are the same point (within tolerance)
        let is_same = x.iter().zip(y.iter()).all(|(a, b)| (a - b).abs() < 1e-10);

        if is_same {
            Ok(self.variance)
        } else {
            Ok(0.0)
        }
    }

    fn name(&self) -> &str {
        "WhiteNoise"
    }
}

/// Constant kernel: K(x, y) = σ²
///
/// Produces constant predictions. Useful as a building block in composite kernels.
#[derive(Debug, Clone)]
pub struct ConstantKernel {
    /// Constant value (variance)
    variance: f64,
}

impl ConstantKernel {
    /// Create a new constant kernel.
    pub fn new(variance: f64) -> Result<Self> {
        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }
        Ok(Self { variance })
    }

    /// Get the variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }
}

impl Kernel for ConstantKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != y.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![x.len()],
                got: vec![y.len()],
                context: "Constant kernel".to_string(),
            });
        }
        Ok(self.variance)
    }

    fn name(&self) -> &str {
        "Constant"
    }
}

/// Dot Product kernel (linear kernel with scaling and bias).
///
/// K(x, y) = σ_b² + σ² * (x · y)
///
/// Useful for Bayesian linear regression. The bias term σ_b² allows
/// the linear model to have a non-zero intercept.
#[derive(Debug, Clone)]
pub struct DotProductKernel {
    /// Signal variance (scaling)
    variance: f64,
    /// Offset variance (bias)
    variance_bias: f64,
}

impl DotProductKernel {
    /// Create a new dot product kernel.
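    ///
    /// # Example
    ///
    /// A minimal sketch showing both terms, K(x, y) = σ_b² + σ² * (x · y):
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::DotProductKernel;
    /// use tensorlogic_sklears_kernels::Kernel;
    ///
    /// let kernel = DotProductKernel::new(2.0, 1.0).unwrap();
    /// // dot = 1*3 + 2*4 = 11, so K = 1 + 2 * 11 = 23
    /// let sim = kernel.compute(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
    /// assert!((sim - 23.0).abs() < 1e-12);
    /// ```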
    pub fn new(variance: f64, variance_bias: f64) -> Result<Self> {
        if variance < 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be non-negative".to_string(),
            });
        }
        if variance_bias < 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance_bias".to_string(),
                value: variance_bias.to_string(),
                reason: "variance_bias must be non-negative".to_string(),
            });
        }
        Ok(Self {
            variance,
            variance_bias,
        })
    }

    /// Create a simple dot product kernel (variance=1, no bias).
    pub fn simple() -> Self {
        Self {
            variance: 1.0,
            variance_bias: 0.0,
        }
    }

    /// Get the variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }

    /// Get the bias variance.
    pub fn variance_bias(&self) -> f64 {
        self.variance_bias
    }
}

impl Kernel for DotProductKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        if x.len() != y.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![x.len()],
                got: vec![y.len()],
                context: "Dot Product kernel".to_string(),
            });
        }

        let dot: f64 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
        Ok(self.variance_bias + self.variance * dot)
    }

    fn name(&self) -> &str {
        "DotProduct"
    }
}

/// Scaled kernel wrapper that multiplies a kernel by a variance parameter.
///
/// K_scaled(x, y) = σ² * K(x, y)
///
/// This is useful for controlling the output scale of any kernel.
#[derive(Debug, Clone)]
pub struct ScaledKernel<K: Kernel> {
    /// The base kernel
    kernel: K,
    /// The scaling factor (variance)
    variance: f64,
}

impl<K: Kernel> ScaledKernel<K> {
    /// Create a scaled kernel.
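    ///
    /// # Example
    ///
    /// A minimal sketch wrapping another kernel from this module:
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::ard_kernel::{ArdRbfKernel, ScaledKernel};
    /// use tensorlogic_sklears_kernels::Kernel;
    ///
    /// let base = ArdRbfKernel::new(vec![1.0, 1.0]).unwrap();
    /// let scaled = ScaledKernel::new(base, 2.0).unwrap();
    /// // Base self-similarity is 1.0, so the scaled value is the variance, 2.0.
    /// let sim = scaled.compute(&[0.0, 0.0], &[0.0, 0.0]).unwrap();
    /// assert!((sim - 2.0).abs() < 1e-12);
    /// ```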
    pub fn new(kernel: K, variance: f64) -> Result<Self> {
        if variance <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "variance".to_string(),
                value: variance.to_string(),
                reason: "variance must be positive".to_string(),
            });
        }
        Ok(Self { kernel, variance })
    }

    /// Get the base kernel.
    pub fn kernel(&self) -> &K {
        &self.kernel
    }

    /// Get the variance.
    pub fn variance(&self) -> f64 {
        self.variance
    }
}

impl<K: Kernel> Kernel for ScaledKernel<K> {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let base_value = self.kernel.compute(x, y)?;
        Ok(self.variance * base_value)
    }

    fn name(&self) -> &str {
        "Scaled"
    }

    fn is_psd(&self) -> bool {
        self.kernel.is_psd()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ===== ARD RBF Kernel Tests =====

    #[test]
    fn test_ard_rbf_kernel_basic() {
        let kernel = ArdRbfKernel::new(vec![1.0, 1.0, 1.0]).unwrap();
        assert_eq!(kernel.name(), "ARD-RBF");
        assert_eq!(kernel.ndim(), 3);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.0];

        // Self-similarity should be variance (1.0)
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ard_rbf_kernel_different_length_scales() {
        // High length scale in dimension 0 makes that dimension less important
        let kernel = ArdRbfKernel::new(vec![10.0, 1.0, 1.0]).unwrap();

        let x = vec![0.0, 0.0, 0.0];
        let y1 = vec![1.0, 0.0, 0.0]; // Difference in dim 0 (large length scale)
        let y2 = vec![0.0, 1.0, 0.0]; // Difference in dim 1 (small length scale)

        let sim1 = kernel.compute(&x, &y1).unwrap();
        let sim2 = kernel.compute(&x, &y2).unwrap();

        // y1 should be MORE similar because dim 0 has large length scale (less relevant)
        assert!(sim1 > sim2);
    }

    #[test]
    fn test_ard_rbf_kernel_with_variance() {
        let kernel = ArdRbfKernel::with_variance(vec![1.0, 1.0], 2.0).unwrap();
        assert!((kernel.variance() - 2.0).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 2.0).abs() < 1e-10); // Self-similarity = variance
    }

    #[test]
    fn test_ard_rbf_kernel_gradient() {
        let kernel = ArdRbfKernel::new(vec![1.0, 2.0]).unwrap();
        let x = vec![0.0, 0.0];
        let y = vec![1.0, 1.0];

        let grad = kernel.compute_gradient(&x, &y).unwrap();

        // Check that value matches compute
        let value = kernel.compute(&x, &y).unwrap();
        assert!((grad.value - value).abs() < 1e-10);

        // Gradients should have correct dimensions
        assert_eq!(grad.grad_length_scales.len(), 2);
    }

    #[test]
    fn test_ard_rbf_kernel_invalid_empty() {
        let result = ArdRbfKernel::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_ard_rbf_kernel_invalid_negative() {
        let result = ArdRbfKernel::new(vec![1.0, -1.0, 1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_ard_rbf_kernel_invalid_variance() {
        let result = ArdRbfKernel::with_variance(vec![1.0], 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_ard_rbf_kernel_dimension_mismatch() {
        let kernel = ArdRbfKernel::new(vec![1.0, 1.0]).unwrap();
        let x = vec![1.0, 2.0, 3.0]; // 3 dims
        let y = vec![1.0, 2.0]; // 2 dims

        assert!(kernel.compute(&x, &y).is_err());
    }

    #[test]
    fn test_ard_rbf_kernel_symmetry() {
        let kernel = ArdRbfKernel::new(vec![1.0, 2.0, 0.5]).unwrap();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let k_xy = kernel.compute(&x, &y).unwrap();
        let k_yx = kernel.compute(&y, &x).unwrap();
        assert!((k_xy - k_yx).abs() < 1e-10);
    }

    // ===== ARD Matérn Kernel Tests =====

    #[test]
    fn test_ard_matern_kernel_nu_3_2() {
        let kernel = ArdMaternKernel::nu_3_2(vec![1.0, 1.0]).unwrap();
        assert_eq!(kernel.name(), "ARD-Matérn");
        assert!((kernel.nu() - 1.5).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ard_matern_kernel_nu_5_2() {
        let kernel = ArdMaternKernel::nu_5_2(vec![1.0, 2.0]).unwrap();
        assert!((kernel.nu() - 2.5).abs() < 1e-10);

        let x = vec![0.0, 0.0];
        let y = vec![0.5, 0.5];
        let sim = kernel.compute(&x, &y).unwrap();
        assert!(sim > 0.0 && sim < 1.0);
    }

    #[test]
    fn test_ard_matern_kernel_exponential() {
        let kernel = ArdMaternKernel::exponential(vec![1.0]).unwrap();
        assert!((kernel.nu() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_ard_matern_kernel_invalid_nu() {
        // Only 0.5, 1.5, 2.5 are supported
        let result = ArdMaternKernel::new(vec![1.0], 1.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_ard_matern_kernel_different_length_scales() {
        let kernel = ArdMaternKernel::nu_3_2(vec![10.0, 1.0]).unwrap();

        let x = vec![0.0, 0.0];
        let y1 = vec![1.0, 0.0];
        let y2 = vec![0.0, 1.0];

        let sim1 = kernel.compute(&x, &y1).unwrap();
        let sim2 = kernel.compute(&x, &y2).unwrap();

        // Larger length scale in dim 0 makes it less relevant
        assert!(sim1 > sim2);
    }

    // ===== ARD Rational Quadratic Kernel Tests =====

    #[test]
    fn test_ard_rq_kernel_basic() {
        let kernel = ArdRationalQuadraticKernel::new(vec![1.0, 1.0], 2.0).unwrap();
        assert_eq!(kernel.name(), "ARD-RationalQuadratic");

        let x = vec![0.0, 0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ard_rq_kernel_with_variance() {
        let kernel = ArdRationalQuadraticKernel::with_variance(vec![1.0], 2.0, 3.0).unwrap();
        assert!((kernel.variance() - 3.0).abs() < 1e-10);

        let x = vec![0.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 3.0).abs() < 1e-10);
    }

    // ===== White Noise Kernel Tests =====

    #[test]
    fn test_white_noise_kernel_same_point() {
        let kernel = WhiteNoiseKernel::new(0.1).unwrap();
        assert_eq!(kernel.name(), "WhiteNoise");

        let x = vec![1.0, 2.0, 3.0];
        let sim = kernel.compute(&x, &x).unwrap();
        assert!((sim - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_white_noise_kernel_different_points() {
        let kernel = WhiteNoiseKernel::new(0.1).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![1.0, 2.0, 3.1]; // Slightly different
        let sim = kernel.compute(&x, &y).unwrap();
        assert!(sim.abs() < 1e-10); // Should be 0
    }

    #[test]
    fn test_white_noise_kernel_invalid() {
        let result = WhiteNoiseKernel::new(0.0);
        assert!(result.is_err());

        let result = WhiteNoiseKernel::new(-1.0);
        assert!(result.is_err());
    }

    // ===== Constant Kernel Tests =====

    #[test]
    fn test_constant_kernel() {
        let kernel = ConstantKernel::new(2.5).unwrap();
        assert_eq!(kernel.name(), "Constant");

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_constant_kernel_invalid() {
        assert!(ConstantKernel::new(0.0).is_err());
        assert!(ConstantKernel::new(-1.0).is_err());
    }

    // ===== Dot Product Kernel Tests =====

    #[test]
    fn test_dot_product_kernel_simple() {
        let kernel = DotProductKernel::simple();
        assert_eq!(kernel.name(), "DotProduct");

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // dot = 1*4 + 2*5 + 3*6 = 32
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_product_kernel_with_bias() {
        let kernel = DotProductKernel::new(1.0, 5.0).unwrap();

        let x = vec![1.0, 0.0];
        let y = vec![0.0, 1.0]; // Orthogonal

        // dot = 0, result = bias + variance * dot = 5 + 0 = 5
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_product_kernel_with_variance() {
        let kernel = DotProductKernel::new(2.0, 0.0).unwrap();

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        // dot = 11, result = 2 * 11 = 22
        let sim = kernel.compute(&x, &y).unwrap();
        assert!((sim - 22.0).abs() < 1e-10);
    }

    // ===== Scaled Kernel Tests =====

    #[test]
    fn test_scaled_kernel() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let scaled = ScaledKernel::new(base, 2.0).unwrap();
        assert_eq!(scaled.name(), "Scaled");

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // Linear: dot = 32, scaled: 2 * 32 = 64
        let sim = scaled.compute(&x, &y).unwrap();
        assert!((sim - 64.0).abs() < 1e-10);
    }

    #[test]
    fn test_scaled_kernel_invalid() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let result = ScaledKernel::new(base, 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_scaled_kernel_psd() {
        use crate::tensor_kernels::LinearKernel;

        let base = LinearKernel::new();
        let scaled = ScaledKernel::new(base, 2.0).unwrap();
        assert!(scaled.is_psd());
    }

    // ===== Integration Tests =====

    #[test]
    fn test_ard_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(ArdRbfKernel::new(vec![1.0, 2.0]).unwrap()),
            Box::new(ArdMaternKernel::nu_3_2(vec![1.0, 2.0]).unwrap()),
            Box::new(ArdRationalQuadraticKernel::new(vec![1.0, 2.0], 2.0).unwrap()),
        ];

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        for kernel in kernels {
            let k_xy = kernel.compute(&x, &y).unwrap();
            let k_yx = kernel.compute(&y, &x).unwrap();
            assert!(
                (k_xy - k_yx).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }

    #[test]
    fn test_utility_kernels_symmetry() {
        let kernels: Vec<Box<dyn Kernel>> = vec![
            Box::new(WhiteNoiseKernel::new(0.1).unwrap()),
            Box::new(ConstantKernel::new(1.0).unwrap()),
            Box::new(DotProductKernel::simple()),
        ];

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        for kernel in kernels {
            let k_xy = kernel.compute(&x, &y).unwrap();
            let k_yx = kernel.compute(&y, &x).unwrap();
            assert!(
                (k_xy - k_yx).abs() < 1e-10,
                "{} not symmetric",
                kernel.name()
            );
        }
    }
}