sklears_gaussian_process/
kernels.rs

1//! Kernel functions for Gaussian Process models
2//!
3//! This module provides various kernel functions that can be used with
4//! Gaussian Process models. All implementations comply with SciRS2 Policy.
5
// SciRS2 Policy - Use scirs2-core for ndarray types and array operations
7use scirs2_core::ndarray::{Array2, ArrayView1};
// Shared sklears error and result types
9use sklears_core::error::{Result as SklResult, SklearsError};
10
11// Core kernel trait - import and re-export from the kernel_trait module
12pub use crate::kernel_trait::Kernel;
13
/// Isotropic RBF (Radial Basis Function, squared-exponential) kernel.
///
/// A single `length_scale` is shared by every input dimension; the `Kernel`
/// implementation evaluates `exp(-||x - x'||^2 / (2 * length_scale^2))`.
#[derive(Debug, Clone)]
pub struct RBF {
    /// Length scale applied uniformly to all dimensions.
    length_scale: f64,
}

impl RBF {
    /// Builds an RBF kernel from the given length scale.
    ///
    /// The value is stored as-is; no range validation is performed.
    pub fn new(length_scale: f64) -> Self {
        RBF { length_scale }
    }
}
25
26impl Kernel for RBF {
27    fn compute_kernel_matrix(
28        &self,
29        X1: &Array2<f64>,
30        X2: Option<&Array2<f64>>,
31    ) -> SklResult<Array2<f64>> {
32        let X2 = X2.unwrap_or(X1);
33        let n1 = X1.nrows();
34        let n2 = X2.nrows();
35        let mut K = Array2::<f64>::zeros((n1, n2));
36
37        for i in 0..n1 {
38            for j in 0..n2 {
39                let x1 = X1.row(i);
40                let x2 = X2.row(j);
41                K[[i, j]] = self.kernel(&x1, &x2);
42            }
43        }
44        Ok(K)
45    }
46
47    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
48        let mut sq_dist = 0.0;
49        for (a, b) in x1.iter().zip(x2.iter()) {
50            sq_dist += (a - b).powi(2);
51        }
52        (-sq_dist / (2.0 * self.length_scale.powi(2))).exp()
53    }
54
55    fn get_params(&self) -> Vec<f64> {
56        vec![self.length_scale]
57    }
58
59    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
60        if params.len() != 1 {
61            return Err(SklearsError::InvalidInput(
62                "RBF kernel requires exactly 1 parameter".to_string(),
63            ));
64        }
65        self.length_scale = params[0];
66        Ok(())
67    }
68
69    fn clone_box(&self) -> Box<dyn Kernel> {
70        Box::new(self.clone())
71    }
72}
73
74/// ARD RBF (Automatic Relevance Determination RBF) Kernel
75///
76/// This kernel is similar to RBF but has a separate length scale for each dimension,
77/// allowing it to automatically determine which dimensions are most relevant for
78/// the prediction task.
79///
80/// The kernel is defined as:
81/// k(x, x') = exp(-0.5 * sum((x_i - x'_i)^2 / length_scales[i]^2))
82///
83/// # Examples
84///
85/// ```ignore
86/// use sklears_gaussian_process::kernels::ARDRBF;
87/// use scirs2_core::ndarray::Array1;
88///
89/// // Create ARD RBF kernel with different length scales for each dimension
90/// let length_scales = Array1::from_vec(vec![1.0, 2.0, 0.5]);
91/// let kernel = ARDRBF::new(length_scales);
92/// ```
93#[derive(Debug, Clone)]
94pub struct ARDRBF {
95    /// Length scale for each dimension
96    length_scales: scirs2_core::ndarray::Array1<f64>,
97}
98
99impl ARDRBF {
100    /// Create a new ARD RBF kernel with specified length scales for each dimension
101    ///
102    /// # Arguments
103    ///
104    /// * `length_scales` - Array of length scales, one for each dimension
105    pub fn new(length_scales: scirs2_core::ndarray::Array1<f64>) -> Self {
106        Self { length_scales }
107    }
108
109    /// Create a new ARD RBF kernel with uniform length scales
110    ///
111    /// # Arguments
112    ///
113    /// * `n_dims` - Number of dimensions
114    /// * `length_scale` - Initial length scale value for all dimensions
115    pub fn new_uniform(n_dims: usize, length_scale: f64) -> Self {
116        Self {
117            length_scales: scirs2_core::ndarray::Array1::from_elem(n_dims, length_scale),
118        }
119    }
120
121    /// Get the number of dimensions
122    pub fn n_dimensions(&self) -> usize {
123        self.length_scales.len()
124    }
125}
126
127impl Kernel for ARDRBF {
128    fn compute_kernel_matrix(
129        &self,
130        X1: &Array2<f64>,
131        X2: Option<&Array2<f64>>,
132    ) -> SklResult<Array2<f64>> {
133        let X2 = X2.unwrap_or(X1);
134        let n1 = X1.nrows();
135        let n2 = X2.nrows();
136
137        // Validate dimensions
138        if X1.ncols() != self.length_scales.len() {
139            return Err(SklearsError::InvalidInput(format!(
140                "X1 has {} dimensions but kernel expects {}",
141                X1.ncols(),
142                self.length_scales.len()
143            )));
144        }
145        if X2.ncols() != self.length_scales.len() {
146            return Err(SklearsError::InvalidInput(format!(
147                "X2 has {} dimensions but kernel expects {}",
148                X2.ncols(),
149                self.length_scales.len()
150            )));
151        }
152
153        let mut K = Array2::<f64>::zeros((n1, n2));
154
155        for i in 0..n1 {
156            for j in 0..n2 {
157                let x1 = X1.row(i);
158                let x2 = X2.row(j);
159                K[[i, j]] = self.kernel(&x1, &x2);
160            }
161        }
162        Ok(K)
163    }
164
165    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
166        // Compute weighted squared distance: sum((x_i - x'_i)^2 / l_i^2)
167        let mut weighted_sq_dist = 0.0;
168        for ((a, b), length_scale) in x1.iter().zip(x2.iter()).zip(self.length_scales.iter()) {
169            let diff = a - b;
170            weighted_sq_dist += (diff * diff) / (length_scale * length_scale);
171        }
172        (-0.5 * weighted_sq_dist).exp()
173    }
174
175    fn get_params(&self) -> Vec<f64> {
176        self.length_scales.to_vec()
177    }
178
179    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
180        if params.len() != self.length_scales.len() {
181            return Err(SklearsError::InvalidInput(format!(
182                "ARD RBF kernel requires exactly {} parameters (one per dimension), got {}",
183                self.length_scales.len(),
184                params.len()
185            )));
186        }
187        for (i, &param) in params.iter().enumerate() {
188            self.length_scales[i] = param;
189        }
190        Ok(())
191    }
192
193    fn clone_box(&self) -> Box<dyn Kernel> {
194        Box::new(self.clone())
195    }
196}
197
/// Matérn kernel.
///
/// Stores a length scale together with the `nu` parameter; both are reported
/// by `get_params` and replaced by `set_params`.
#[derive(Debug, Clone)]
pub struct Matern {
    /// Length scale dividing the Euclidean distance between inputs.
    length_scale: f64,
    /// The Matérn `nu` (smoothness) parameter.
    nu: f64,
}

impl Matern {
    /// Builds a Matérn kernel from a length scale and a `nu` value.
    ///
    /// Both values are stored as-is; no range validation is performed.
    pub fn new(length_scale: f64, nu: f64) -> Self {
        Matern { length_scale, nu }
    }
}
210
211impl Kernel for Matern {
212    fn compute_kernel_matrix(
213        &self,
214        X1: &Array2<f64>,
215        X2: Option<&Array2<f64>>,
216    ) -> SklResult<Array2<f64>> {
217        let X2 = X2.unwrap_or(X1);
218        let n1 = X1.nrows();
219        let n2 = X2.nrows();
220        let mut K = Array2::<f64>::zeros((n1, n2));
221
222        for i in 0..n1 {
223            for j in 0..n2 {
224                let x1 = X1.row(i);
225                let x2 = X2.row(j);
226                K[[i, j]] = self.kernel(&x1, &x2);
227            }
228        }
229        Ok(K)
230    }
231
232    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
233        let mut sq_dist = 0.0;
234        for (a, b) in x1.iter().zip(x2.iter()) {
235            sq_dist += (a - b).powi(2);
236        }
237        let dist = sq_dist.sqrt();
238
239        if dist == 0.0 {
240            return 1.0;
241        }
242
243        let sqrt_3_dist = (3.0_f64).sqrt() * dist / self.length_scale;
244        (1.0 + sqrt_3_dist) * (-sqrt_3_dist).exp()
245    }
246
247    fn get_params(&self) -> Vec<f64> {
248        vec![self.length_scale, self.nu]
249    }
250
251    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
252        if params.len() != 2 {
253            return Err(SklearsError::InvalidInput(
254                "Matern kernel requires exactly 2 parameters".to_string(),
255            ));
256        }
257        self.length_scale = params[0];
258        self.nu = params[1];
259        Ok(())
260    }
261
262    fn clone_box(&self) -> Box<dyn Kernel> {
263        Box::new(self.clone())
264    }
265}
266
/// Linear kernel.
///
/// The `Kernel` implementation evaluates `sigma_0_sq + sigma_1_sq * <x, x'>`.
#[derive(Debug, Clone)]
pub struct Linear {
    /// Constant offset added to the scaled dot product.
    sigma_0_sq: f64,
    /// Multiplier applied to the dot product.
    sigma_1_sq: f64,
}

impl Linear {
    /// Builds a linear kernel from its offset and dot-product multiplier.
    ///
    /// Both values are stored as-is; no range validation is performed.
    pub fn new(sigma_0_sq: f64, sigma_1_sq: f64) -> Self {
        Linear {
            sigma_0_sq,
            sigma_1_sq,
        }
    }
}
282
283impl Kernel for Linear {
284    fn compute_kernel_matrix(
285        &self,
286        X1: &Array2<f64>,
287        X2: Option<&Array2<f64>>,
288    ) -> SklResult<Array2<f64>> {
289        let X2 = X2.unwrap_or(X1);
290        let n1 = X1.nrows();
291        let n2 = X2.nrows();
292        let mut K = Array2::<f64>::zeros((n1, n2));
293
294        for i in 0..n1 {
295            for j in 0..n2 {
296                let x1 = X1.row(i);
297                let x2 = X2.row(j);
298                K[[i, j]] = self.kernel(&x1, &x2);
299            }
300        }
301        Ok(K)
302    }
303
304    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
305        let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
306        self.sigma_0_sq + self.sigma_1_sq * dot_product
307    }
308
309    fn get_params(&self) -> Vec<f64> {
310        vec![self.sigma_0_sq, self.sigma_1_sq]
311    }
312
313    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
314        if params.len() != 2 {
315            return Err(SklearsError::InvalidInput(
316                "Linear kernel requires exactly 2 parameters".to_string(),
317            ));
318        }
319        self.sigma_0_sq = params[0];
320        self.sigma_1_sq = params[1];
321        Ok(())
322    }
323
324    fn clone_box(&self) -> Box<dyn Kernel> {
325        Box::new(self.clone())
326    }
327}
328
/// Polynomial kernel.
///
/// The `Kernel` implementation evaluates `(gamma * <x, x'> + coef0)^degree`.
#[derive(Debug, Clone)]
pub struct Polynomial {
    /// Multiplier applied to the dot product.
    gamma: f64,
    /// Constant added before exponentiation.
    coef0: f64,
    /// Exponent; stored as a float and applied with `powf`.
    degree: f64,
}

impl Polynomial {
    /// Builds a polynomial kernel from its three coefficients.
    ///
    /// All values are stored as-is; no range validation is performed.
    pub fn new(gamma: f64, coef0: f64, degree: f64) -> Self {
        Polynomial {
            gamma,
            coef0,
            degree,
        }
    }
}
346
347impl Kernel for Polynomial {
348    fn compute_kernel_matrix(
349        &self,
350        X1: &Array2<f64>,
351        X2: Option<&Array2<f64>>,
352    ) -> SklResult<Array2<f64>> {
353        let X2 = X2.unwrap_or(X1);
354        let n1 = X1.nrows();
355        let n2 = X2.nrows();
356        let mut K = Array2::<f64>::zeros((n1, n2));
357
358        for i in 0..n1 {
359            for j in 0..n2 {
360                let x1 = X1.row(i);
361                let x2 = X2.row(j);
362                K[[i, j]] = self.kernel(&x1, &x2);
363            }
364        }
365        Ok(K)
366    }
367
368    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
369        let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
370        (self.gamma * dot_product + self.coef0).powf(self.degree)
371    }
372
373    fn get_params(&self) -> Vec<f64> {
374        vec![self.gamma, self.coef0, self.degree]
375    }
376
377    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
378        if params.len() != 3 {
379            return Err(SklearsError::InvalidInput(
380                "Polynomial kernel requires exactly 3 parameters".to_string(),
381            ));
382        }
383        self.gamma = params[0];
384        self.coef0 = params[1];
385        self.degree = params[2];
386        Ok(())
387    }
388
389    fn clone_box(&self) -> Box<dyn Kernel> {
390        Box::new(self.clone())
391    }
392}
393
/// Rational Quadratic kernel.
///
/// The `Kernel` implementation evaluates
/// `(1 + ||x - x'||^2 / (2 * alpha * length_scale^2))^(-alpha)`.
#[derive(Debug, Clone)]
pub struct RationalQuadratic {
    /// Length scale dividing the squared distance.
    length_scale: f64,
    /// Exponent parameter; also scales the denominator.
    alpha: f64,
}

impl RationalQuadratic {
    /// Builds a rational quadratic kernel from a length scale and `alpha`.
    ///
    /// Both values are stored as-is; no range validation is performed.
    pub fn new(length_scale: f64, alpha: f64) -> Self {
        RationalQuadratic {
            length_scale,
            alpha,
        }
    }
}
409
410impl Kernel for RationalQuadratic {
411    fn compute_kernel_matrix(
412        &self,
413        X1: &Array2<f64>,
414        X2: Option<&Array2<f64>>,
415    ) -> SklResult<Array2<f64>> {
416        let X2 = X2.unwrap_or(X1);
417        let n1 = X1.nrows();
418        let n2 = X2.nrows();
419        let mut K = Array2::<f64>::zeros((n1, n2));
420
421        for i in 0..n1 {
422            for j in 0..n2 {
423                let x1 = X1.row(i);
424                let x2 = X2.row(j);
425                K[[i, j]] = self.kernel(&x1, &x2);
426            }
427        }
428        Ok(K)
429    }
430
431    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
432        let mut sq_dist = 0.0;
433        for (a, b) in x1.iter().zip(x2.iter()) {
434            sq_dist += (a - b).powi(2);
435        }
436        (1.0 + sq_dist / (2.0 * self.alpha * self.length_scale.powi(2))).powf(-self.alpha)
437    }
438
439    fn get_params(&self) -> Vec<f64> {
440        vec![self.length_scale, self.alpha]
441    }
442
443    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
444        if params.len() != 2 {
445            return Err(SklearsError::InvalidInput(
446                "RationalQuadratic kernel requires exactly 2 parameters".to_string(),
447            ));
448        }
449        self.length_scale = params[0];
450        self.alpha = params[1];
451        Ok(())
452    }
453
454    fn clone_box(&self) -> Box<dyn Kernel> {
455        Box::new(self.clone())
456    }
457}
458
/// Exp-Sine-Squared (periodic) kernel.
///
/// The `Kernel` implementation evaluates
/// `exp(-2 * sin^2(pi * ||x - x'|| / periodicity) / length_scale^2)`.
#[derive(Debug, Clone)]
pub struct ExpSineSquared {
    /// Length scale dividing the squared sine term.
    length_scale: f64,
    /// Period of the sine term, in units of input distance.
    periodicity: f64,
}

impl ExpSineSquared {
    /// Builds a periodic kernel from a length scale and a period.
    ///
    /// Both values are stored as-is; no range validation is performed.
    pub fn new(length_scale: f64, periodicity: f64) -> Self {
        ExpSineSquared {
            length_scale,
            periodicity,
        }
    }
}
474
475impl Kernel for ExpSineSquared {
476    fn compute_kernel_matrix(
477        &self,
478        X1: &Array2<f64>,
479        X2: Option<&Array2<f64>>,
480    ) -> SklResult<Array2<f64>> {
481        let X2 = X2.unwrap_or(X1);
482        let n1 = X1.nrows();
483        let n2 = X2.nrows();
484        let mut K = Array2::<f64>::zeros((n1, n2));
485
486        for i in 0..n1 {
487            for j in 0..n2 {
488                let x1 = X1.row(i);
489                let x2 = X2.row(j);
490                K[[i, j]] = self.kernel(&x1, &x2);
491            }
492        }
493        Ok(K)
494    }
495
496    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
497        let dist = x1
498            .iter()
499            .zip(x2.iter())
500            .map(|(a, b)| (a - b).powi(2))
501            .sum::<f64>()
502            .sqrt();
503        let sin_term = (std::f64::consts::PI * dist / self.periodicity).sin();
504        (-2.0 * sin_term.powi(2) / self.length_scale.powi(2)).exp()
505    }
506
507    fn get_params(&self) -> Vec<f64> {
508        vec![self.length_scale, self.periodicity]
509    }
510
511    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
512        if params.len() != 2 {
513            return Err(SklearsError::InvalidInput(
514                "ExpSineSquared kernel requires exactly 2 parameters".to_string(),
515            ));
516        }
517        self.length_scale = params[0];
518        self.periodicity = params[1];
519        Ok(())
520    }
521
522    fn clone_box(&self) -> Box<dyn Kernel> {
523        Box::new(self.clone())
524    }
525}
526
/// White (noise) kernel.
///
/// Contributes `noise_level` where the two inputs coincide and zero elsewhere.
#[derive(Debug, Clone)]
pub struct WhiteKernel {
    /// Value returned for coinciding inputs.
    noise_level: f64,
}

impl WhiteKernel {
    /// Builds a white-noise kernel with the given noise level.
    ///
    /// The value is stored as-is; no range validation is performed.
    pub fn new(noise_level: f64) -> Self {
        WhiteKernel { noise_level }
    }
}
538
539impl Kernel for WhiteKernel {
540    fn compute_kernel_matrix(
541        &self,
542        X1: &Array2<f64>,
543        X2: Option<&Array2<f64>>,
544    ) -> SklResult<Array2<f64>> {
545        let n1 = X1.nrows();
546        let n2 = X2.map_or(n1, |x| x.nrows());
547        let mut K = Array2::<f64>::zeros((n1, n2));
548
549        // White kernel is only non-zero on the diagonal when X1 == X2
550        if X2.is_none() {
551            for i in 0..n1 {
552                K[[i, i]] = self.noise_level;
553            }
554        }
555        Ok(K)
556    }
557
558    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
559        // Check if vectors are identical
560        let identical = x1.iter().zip(x2.iter()).all(|(a, b)| (a - b).abs() < 1e-10);
561        if identical {
562            self.noise_level
563        } else {
564            0.0
565        }
566    }
567
568    fn get_params(&self) -> Vec<f64> {
569        vec![self.noise_level]
570    }
571
572    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
573        if params.len() != 1 {
574            return Err(SklearsError::InvalidInput(
575                "WhiteKernel requires exactly 1 parameter".to_string(),
576            ));
577        }
578        self.noise_level = params[0];
579        Ok(())
580    }
581
582    fn clone_box(&self) -> Box<dyn Kernel> {
583        Box::new(self.clone())
584    }
585}
586
/// Constant kernel.
///
/// Returns the same value for every pair of inputs.
#[derive(Debug, Clone)]
pub struct ConstantKernel {
    /// Value returned for any pair of inputs.
    constant_value: f64,
}

impl ConstantKernel {
    /// Builds a constant kernel that always evaluates to `constant_value`.
    pub fn new(constant_value: f64) -> Self {
        ConstantKernel { constant_value }
    }
}
598
599impl Kernel for ConstantKernel {
600    fn compute_kernel_matrix(
601        &self,
602        X1: &Array2<f64>,
603        X2: Option<&Array2<f64>>,
604    ) -> SklResult<Array2<f64>> {
605        let n1 = X1.nrows();
606        let n2 = X2.map_or(n1, |x| x.nrows());
607        Ok(Array2::<f64>::from_elem((n1, n2), self.constant_value))
608    }
609
610    fn kernel(&self, _x1: &ArrayView1<f64>, _x2: &ArrayView1<f64>) -> f64 {
611        self.constant_value
612    }
613
614    fn get_params(&self) -> Vec<f64> {
615        vec![self.constant_value]
616    }
617
618    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
619        if params.len() != 1 {
620            return Err(SklearsError::InvalidInput(
621                "ConstantKernel requires exactly 1 parameter".to_string(),
622            ));
623        }
624        self.constant_value = params[0];
625        Ok(())
626    }
627
628    fn clone_box(&self) -> Box<dyn Kernel> {
629        Box::new(self.clone())
630    }
631}
632
633/// Sum kernel for composing kernels additively
634#[derive(Debug, Clone)]
635pub struct SumKernel {
636    kernels: Vec<Box<dyn Kernel>>,
637}
638
639impl SumKernel {
640    pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Self {
641        Self { kernels }
642    }
643}
644
645impl Kernel for SumKernel {
646    fn compute_kernel_matrix(
647        &self,
648        X1: &Array2<f64>,
649        X2: Option<&Array2<f64>>,
650    ) -> SklResult<Array2<f64>> {
651        if self.kernels.is_empty() {
652            return Err(SklearsError::InvalidInput(
653                "SumKernel requires at least one kernel".to_string(),
654            ));
655        }
656
657        let mut result = self.kernels[0].compute_kernel_matrix(X1, X2)?;
658        for kernel in &self.kernels[1..] {
659            let k_matrix = kernel.compute_kernel_matrix(X1, X2)?;
660            result = result + k_matrix;
661        }
662        Ok(result)
663    }
664
665    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
666        self.kernels.iter().map(|k| k.kernel(x1, x2)).sum()
667    }
668
669    fn get_params(&self) -> Vec<f64> {
670        self.kernels.iter().flat_map(|k| k.get_params()).collect()
671    }
672
673    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
674        let mut offset = 0;
675        for kernel in &mut self.kernels {
676            let n_params = kernel.get_params().len();
677            if offset + n_params > params.len() {
678                return Err(SklearsError::InvalidInput(
679                    "Not enough parameters for SumKernel".to_string(),
680                ));
681            }
682            kernel.set_params(&params[offset..offset + n_params])?;
683            offset += n_params;
684        }
685        Ok(())
686    }
687
688    fn clone_box(&self) -> Box<dyn Kernel> {
689        Box::new(Self {
690            kernels: self.kernels.iter().map(|k| k.clone_box()).collect(),
691        })
692    }
693}
694
695/// Product kernel for composing kernels multiplicatively
696#[derive(Debug, Clone)]
697pub struct ProductKernel {
698    kernels: Vec<Box<dyn Kernel>>,
699}
700
701impl ProductKernel {
702    pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Self {
703        Self { kernels }
704    }
705}
706
707impl Kernel for ProductKernel {
708    fn compute_kernel_matrix(
709        &self,
710        X1: &Array2<f64>,
711        X2: Option<&Array2<f64>>,
712    ) -> SklResult<Array2<f64>> {
713        if self.kernels.is_empty() {
714            return Err(SklearsError::InvalidInput(
715                "ProductKernel requires at least one kernel".to_string(),
716            ));
717        }
718
719        let mut result = self.kernels[0].compute_kernel_matrix(X1, X2)?;
720        for kernel in &self.kernels[1..] {
721            let k_matrix = kernel.compute_kernel_matrix(X1, X2)?;
722            result = result * k_matrix;
723        }
724        Ok(result)
725    }
726
727    fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
728        self.kernels.iter().map(|k| k.kernel(x1, x2)).product()
729    }
730
731    fn get_params(&self) -> Vec<f64> {
732        self.kernels.iter().flat_map(|k| k.get_params()).collect()
733    }
734
735    fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
736        let mut offset = 0;
737        for kernel in &mut self.kernels {
738            let n_params = kernel.get_params().len();
739            if offset + n_params > params.len() {
740                return Err(SklearsError::InvalidInput(
741                    "Not enough parameters for ProductKernel".to_string(),
742                ));
743            }
744            kernel.set_params(&params[offset..offset + n_params])?;
745            offset += n_params;
746        }
747        Ok(())
748    }
749
750    fn clone_box(&self) -> Box<dyn Kernel> {
751        Box::new(Self {
752            kernels: self.kernels.iter().map(|k| k.clone_box()).collect(),
753        })
754    }
755}
756
#[cfg(test)]
mod tests {
    use super::*;
    // SciRS2 Policy - ndarray types are taken from the scirs2_core re-export
    use scirs2_core::ndarray::{array, Array1};

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// Builds an ARD RBF kernel from a slice of per-dimension length scales.
    fn ard(scales: &[f64]) -> ARDRBF {
        ARDRBF::new(Array1::from_vec(scales.to_vec()))
    }

    /// True when `a` and `b` differ by less than `TOL`.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    #[test]
    fn test_ardrbf_creation() {
        let kernel = ard(&[1.0, 2.0, 0.5]);
        assert_eq!(kernel.n_dimensions(), 3);
        assert_eq!(kernel.get_params(), vec![1.0, 2.0, 0.5]);
    }

    #[test]
    fn test_ardrbf_creation_uniform() {
        let kernel = ARDRBF::new_uniform(4, 1.5);
        assert_eq!(kernel.n_dimensions(), 4);
        assert_eq!(kernel.get_params(), vec![1.5, 1.5, 1.5, 1.5]);
    }

    #[test]
    fn test_ardrbf_kernel_identical_points() {
        let kernel = ard(&[1.0, 1.0]);

        let point = array![1.0, 2.0];
        let same_point = array![1.0, 2.0];

        let value = kernel.kernel(&point.view(), &same_point.view());
        assert!(
            close(value, 1.0),
            "Kernel of identical points should be 1.0"
        );
    }

    #[test]
    fn test_ardrbf_kernel_different_points() {
        let kernel = ard(&[1.0, 1.0]);

        let origin = array![0.0, 0.0];
        let ones = array![1.0, 1.0];

        let value = kernel.kernel(&origin.view(), &ones.view());
        // Expected: exp(-0.5 * (1/1 + 1/1)) = exp(-1.0) ≈ 0.368
        let expected = (-1.0f64).exp();
        assert!(value > 0.0 && value < 1.0, "Kernel should be between 0 and 1");
        assert!(close(value, expected), "Kernel value mismatch");
    }

    #[test]
    fn test_ardrbf_kernel_matrix() {
        let kernel = ard(&[1.0, 1.0]);

        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let gram = kernel.compute_kernel_matrix(&points, None).unwrap();

        assert_eq!(gram.dim(), (3, 3));
        // Every point has unit similarity with itself
        for i in 0..3 {
            assert!(close(gram[[i, i]], 1.0));
        }
        // Off-diagonal entries mirror each other
        for &(i, j) in [(0, 1), (0, 2), (1, 2)].iter() {
            assert!(close(gram[[i, j]], gram[[j, i]]));
        }
    }

    #[test]
    fn test_ardrbf_relevance_determination() {
        // Small length scale on dimension 0 (very relevant), large on dimension 1
        let kernel = ard(&[0.1, 10.0]);

        let origin = array![0.0, 0.0];
        let step_dim0 = array![1.0, 0.0];
        let step_dim1 = array![0.0, 1.0];

        let along_dim0 = kernel.kernel(&origin.view(), &step_dim0.view());
        let along_dim1 = kernel.kernel(&origin.view(), &step_dim1.view());

        // The same step costs far more similarity along the short-scale dimension
        assert!(
            along_dim0 < along_dim1,
            "Dimension with smaller length scale should have more effect"
        );
    }

    #[test]
    fn test_ardrbf_set_params() {
        let mut kernel = ard(&[1.0, 1.0]);

        kernel.set_params(&[2.0, 3.0]).unwrap();

        assert_eq!(kernel.get_params(), vec![2.0, 3.0]);
    }

    #[test]
    fn test_ardrbf_set_params_wrong_size() {
        let mut kernel = ard(&[1.0, 1.0]);

        // Three values for a two-dimensional kernel
        let outcome = kernel.set_params(&[1.0, 2.0, 3.0]);

        assert!(
            outcome.is_err(),
            "Should error with wrong number of parameters"
        );
    }

    #[test]
    fn test_ardrbf_dimension_validation() {
        let kernel = ard(&[1.0, 1.0]);

        // Three columns for a two-dimensional kernel
        let wrong_width = array![[0.0, 0.0, 0.0]];
        let outcome = kernel.compute_kernel_matrix(&wrong_width, None);

        assert!(
            outcome.is_err(),
            "Should error with wrong number of dimensions"
        );
    }

    #[test]
    fn test_ardrbf_clone() {
        let kernel = ard(&[1.0, 2.0]);
        let copy = kernel.clone();

        assert_eq!(kernel.get_params(), copy.get_params());
        assert_eq!(kernel.n_dimensions(), copy.n_dimensions());
    }

    #[test]
    fn test_ardrbf_vs_rbf_isotropic() {
        // With equal length scales the ARD kernel reduces to the plain RBF kernel
        let shared_scale = 1.5;
        let ard_kernel = ARDRBF::new_uniform(2, shared_scale);
        let rbf_kernel = RBF::new(shared_scale);

        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0];

        let from_ard = ard_kernel.kernel(&a.view(), &b.view());
        let from_rbf = rbf_kernel.kernel(&a.view(), &b.view());

        assert!(
            close(from_ard, from_rbf),
            "ARD RBF with uniform length scales should match RBF"
        );
    }
}