sklears_svm/kernels.rs

//! Kernel functions for Support Vector Machines
//!
//! This module provides kernel implementations for SVMs: basic kernels (linear, RBF,
//! polynomial, sigmoid), distribution kernels (chi-squared, histogram intersection,
//! Hellinger, Jensen-Shannon), a periodic kernel, custom kernels, and graph kernels.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::collections::HashMap;
use std::sync::Arc;

/// Kernel trait for all kernel functions
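///
/// A minimal usage sketch (the `sklears_svm::kernels` doctest path is assumed,
/// not confirmed by this file):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use sklears_svm::kernels::{Kernel, LinearKernel};
///
/// let x = Array1::from_vec(vec![1.0, 2.0]);
/// let y = Array1::from_vec(vec![3.0, 4.0]);
/// // The linear kernel is the plain dot product: 1*3 + 2*4 = 11.
/// assert_eq!(LinearKernel.compute(x.view(), y.view()), 11.0);
/// ```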
pub trait Kernel: Send + Sync + std::fmt::Debug {
    /// Compute kernel value between two vectors
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64;

    /// Compute kernel matrix for two datasets
    fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
        let (n_x, _) = x.dim();
        let (n_y, _) = y.dim();
        let mut kernel_matrix = Array2::zeros((n_x, n_y));

        for i in 0..n_x {
            for j in 0..n_y {
                kernel_matrix[[i, j]] = self.compute(x.row(i), y.row(j));
            }
        }

        kernel_matrix
    }

    /// Get kernel parameters for serialization
    fn parameters(&self) -> HashMap<String, f64>;
}

/// Main kernel type enumeration
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Linear kernel: K(x,y) = x^T y
    Linear,
    /// RBF/Gaussian kernel: K(x,y) = exp(-γ||x-y||²)
    Rbf { gamma: f64 },
    /// Polynomial kernel: K(x,y) = (γ x^T y + r)^d
    Polynomial { gamma: f64, coef0: f64, degree: f64 },
    /// Sigmoid kernel: K(x,y) = tanh(γ x^T y + r)
    Sigmoid { gamma: f64, coef0: f64 },
    /// Precomputed kernel matrix
    Precomputed,
    /// Custom user-defined kernel
    Custom(String),
    /// Cosine similarity kernel
    Cosine,
    /// Chi-squared kernel
    ChiSquared { gamma: f64 },
    /// Histogram intersection kernel
    Intersection,
    /// Hellinger kernel
    Hellinger,
    /// Jensen-Shannon kernel
    JensenShannon,
    /// Periodic kernel
    Periodic { length_scale: f64, period: f64 },
}

/// Create a kernel instance from a [`KernelType`].
///
/// # Panics
///
/// Panics for [`KernelType::Precomputed`] (requires a precomputed matrix supplied
/// with the data) and [`KernelType::Custom`] (requires a user-supplied function;
/// see [`CustomKernel`]).
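///
/// A usage sketch (same assumed module path as above):
///
/// ```
/// use sklears_svm::kernels::{create_kernel, KernelType};
///
/// let kernel = create_kernel(KernelType::Rbf { gamma: 0.5 });
/// assert_eq!(kernel.parameters().get("gamma"), Some(&0.5));
/// ```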
pub fn create_kernel(kernel_type: KernelType) -> Box<dyn Kernel> {
    match kernel_type {
        KernelType::Linear => Box::new(LinearKernel),
        KernelType::Rbf { gamma } => Box::new(RbfKernel { gamma }),
        KernelType::Polynomial {
            gamma,
            coef0,
            degree,
        } => Box::new(PolynomialKernel {
            gamma,
            coef0,
            degree,
        }),
        KernelType::Sigmoid { gamma, coef0 } => Box::new(SigmoidKernel { gamma, coef0 }),
        KernelType::Cosine => Box::new(CosineKernel),
        KernelType::ChiSquared { gamma } => Box::new(ChiSquaredKernel { gamma }),
        KernelType::Intersection => Box::new(IntersectionKernel),
        KernelType::Hellinger => Box::new(HellingerKernel),
        KernelType::JensenShannon => Box::new(JensenShannonKernel),
        KernelType::Periodic {
            length_scale,
            period,
        } => Box::new(PeriodicKernel {
            length_scale,
            period,
        }),
        KernelType::Precomputed => panic!("Precomputed kernels must be created with data"),
        KernelType::Custom(name) => panic!("Custom kernel '{}' not implemented", name),
    }
}

impl<K: Kernel> Kernel for Arc<K> {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        (**self).compute(x, y)
    }

    fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
        (**self).compute_matrix(x, y)
    }

    fn parameters(&self) -> HashMap<String, f64> {
        (**self).parameters()
    }
}

/// Linear kernel implementation
#[derive(Debug, Clone)]
pub struct LinearKernel;

impl Default for LinearKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl LinearKernel {
    pub fn new() -> Self {
        Self
    }
}

impl Kernel for LinearKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        x.dot(&y)
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// RBF (Gaussian) kernel implementation
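///
/// `K(x, y) = exp(-gamma * ||x - y||^2)`; larger `gamma` makes the kernel more
/// local. A sketch of the identity `K(x, x) = 1` (same assumed module path):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use sklears_svm::kernels::{Kernel, RbfKernel};
///
/// let x = Array1::from_vec(vec![1.0, 2.0]);
/// // ||x - x||^2 = 0, so exp(0) = 1 regardless of gamma.
/// assert_eq!(RbfKernel::new(2.0).compute(x.view(), x.view()), 1.0);
/// ```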
#[derive(Debug, Clone)]
pub struct RbfKernel {
    pub gamma: f64,
}

impl RbfKernel {
    pub fn new(gamma: f64) -> Self {
        Self { gamma }
    }
}

impl Kernel for RbfKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        // Accumulate the squared Euclidean distance directly, avoiding
        // temporary owned arrays on every kernel evaluation.
        let squared_distance: f64 = x
            .iter()
            .zip(y.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        (-self.gamma * squared_distance).exp()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("gamma".to_string(), self.gamma);
        params
    }
}

/// Polynomial kernel implementation
#[derive(Debug, Clone)]
pub struct PolynomialKernel {
    pub gamma: f64,
    pub coef0: f64,
    pub degree: f64,
}

impl PolynomialKernel {
    pub fn new(gamma: f64, coef0: f64, degree: f64) -> Self {
        Self {
            gamma,
            coef0,
            degree,
        }
    }
}

impl Kernel for PolynomialKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let dot_product = x.dot(&y);
        (self.gamma * dot_product + self.coef0).powf(self.degree)
    }

    fn parameters(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("gamma".to_string(), self.gamma);
        params.insert("coef0".to_string(), self.coef0);
        params.insert("degree".to_string(), self.degree);
        params
    }
}

/// Sigmoid kernel implementation
#[derive(Debug, Clone)]
pub struct SigmoidKernel {
    pub gamma: f64,
    pub coef0: f64,
}

impl SigmoidKernel {
    pub fn new(gamma: f64, coef0: f64) -> Self {
        Self { gamma, coef0 }
    }
}

impl Kernel for SigmoidKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let dot_product = x.dot(&y);
        (self.gamma * dot_product + self.coef0).tanh()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("gamma".to_string(), self.gamma);
        params.insert("coef0".to_string(), self.coef0);
        params
    }
}

/// Cosine similarity kernel implementation
#[derive(Debug, Clone)]
pub struct CosineKernel;

impl Kernel for CosineKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let dot_product = x.dot(&y);
        let x_norm = x.dot(&x).sqrt();
        let y_norm = y.dot(&y).sqrt();

        if x_norm == 0.0 || y_norm == 0.0 {
            0.0
        } else {
            dot_product / (x_norm * y_norm)
        }
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// Chi-squared kernel implementation
#[derive(Debug, Clone)]
pub struct ChiSquaredKernel {
    pub gamma: f64,
}

impl ChiSquaredKernel {
    pub fn new(gamma: f64) -> Self {
        Self { gamma }
    }
}

impl Kernel for ChiSquaredKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let chi_squared_distance = x
            .iter()
            .zip(y.iter())
            .map(|(a, b)| {
                if a + b > 0.0 {
                    (a - b).powi(2) / (a + b)
                } else {
                    0.0
                }
            })
            .sum::<f64>();

        (-self.gamma * chi_squared_distance).exp()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("gamma".to_string(), self.gamma);
        params
    }
}

/// Histogram intersection kernel implementation
#[derive(Debug, Clone)]
pub struct IntersectionKernel;

impl Kernel for IntersectionKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        x.iter().zip(y.iter()).map(|(a, b)| a.min(*b)).sum()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// Periodic kernel implementation
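///
/// Per-dimension form of the standard periodic (exp-sine-squared) kernel:
/// `K(x, y) = exp(-2 * sum_i sin^2(pi * (x_i - y_i) / period) / length_scale^2)`.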
#[derive(Debug, Clone)]
pub struct PeriodicKernel {
    pub length_scale: f64,
    pub period: f64,
}

impl PeriodicKernel {
    pub fn new(length_scale: f64, period: f64) -> Self {
        Self {
            length_scale,
            period,
        }
    }
}

impl Kernel for PeriodicKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        // Sum the sin^2 terms directly, avoiding temporary owned arrays per call.
        let sin_squared: f64 = x
            .iter()
            .zip(y.iter())
            .map(|(a, b)| {
                let s = (std::f64::consts::PI * (a - b) / self.period).sin();
                s * s
            })
            .sum();
        (-2.0 * sin_squared / (self.length_scale * self.length_scale)).exp()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        let mut params = HashMap::new();
        params.insert("length_scale".to_string(), self.length_scale);
        params.insert("period".to_string(), self.period);
        params
    }
}

/// Custom kernel implementation
#[derive(Debug, Clone)]
pub struct CustomKernel {
    pub name: String,
    pub function: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
}

impl CustomKernel {
    pub fn new(name: String, function: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64) -> Self {
        Self { name, function }
    }
}

impl Kernel for CustomKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        (self.function)(x, y)
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// Main kernel function wrapper
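///
/// Dispatches on a stored [`KernelType`]. A usage sketch (same assumed module path):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use sklears_svm::kernels::{KernelFunction, KernelType};
///
/// let kernel_fn = KernelFunction::new(KernelType::Linear);
/// let x = Array1::from_vec(vec![1.0, 2.0]);
/// let y = Array1::from_vec(vec![3.0, 4.0]);
/// assert_eq!(kernel_fn.compute(x.view(), y.view()), 11.0);
/// ```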
#[derive(Debug, Clone)]
pub struct KernelFunction {
    kernel_type: KernelType,
}

impl KernelFunction {
    pub fn new(kernel_type: KernelType) -> Self {
        Self { kernel_type }
    }

    pub fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        match &self.kernel_type {
            KernelType::Linear => LinearKernel.compute(x, y),
            KernelType::Rbf { gamma } => RbfKernel::new(*gamma).compute(x, y),
            KernelType::Polynomial {
                gamma,
                coef0,
                degree,
            } => PolynomialKernel::new(*gamma, *coef0, *degree).compute(x, y),
            KernelType::Sigmoid { gamma, coef0 } => {
                SigmoidKernel::new(*gamma, *coef0).compute(x, y)
            }
            KernelType::Cosine => CosineKernel.compute(x, y),
            KernelType::ChiSquared { gamma } => ChiSquaredKernel::new(*gamma).compute(x, y),
            KernelType::Intersection => IntersectionKernel.compute(x, y),
            KernelType::Periodic {
                length_scale,
                period,
            } => PeriodicKernel::new(*length_scale, *period).compute(x, y),
            KernelType::Precomputed => {
                // For precomputed kernels, we assume x and y contain indices
                0.0 // Placeholder - actual implementation would look up precomputed values
            }
            KernelType::Custom(_name) => {
                // Default custom implementation
                x.dot(&y)
            }
            KernelType::Hellinger => {
                // Hellinger kernel: the Bhattacharyya coefficient
                // BC(p, q) = sum_i sqrt(p_i * q_i) of the L1-normalized inputs.
                let x_normalized = normalize_vector(&x.to_owned());
                let y_normalized = normalize_vector(&y.to_owned());
                x_normalized
                    .iter()
                    .zip(y_normalized.iter())
                    .map(|(a, b)| (a * b).sqrt())
                    .sum::<f64>()
            }
            KernelType::JensenShannon => {
                // Jensen-Shannon kernel
                let x_normalized = normalize_vector(&x.to_owned());
                let y_normalized = normalize_vector(&y.to_owned());

                let mut js_divergence = 0.0;
                for i in 0..x_normalized.len() {
                    let p = x_normalized[i];
                    let q = y_normalized[i];
                    let m = (p + q) / 2.0;

                    if p > 0.0 && m > 0.0 {
                        js_divergence += p * (p / m).ln();
                    }
                    if q > 0.0 && m > 0.0 {
                        js_divergence += q * (q / m).ln();
                    }
                }
                js_divergence /= 2.0;

                (-js_divergence).exp()
            }
        }
    }

    pub fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
        let (n_x, _) = x.dim();
        let (n_y, _) = y.dim();
        let mut kernel_matrix = Array2::zeros((n_x, n_y));

        for i in 0..n_x {
            for j in 0..n_y {
                kernel_matrix[[i, j]] = self.compute(x.row(i), y.row(j));
            }
        }

        kernel_matrix
    }

    pub fn kernel_type(&self) -> &KernelType {
        &self.kernel_type
    }
}

/// Normalize a vector so its entries sum to 1 (L1 normalization).
/// Returns the input unchanged if it sums to zero.
fn normalize_vector(vec: &Array1<f64>) -> Array1<f64> {
    let sum: f64 = vec.iter().sum();
    if sum == 0.0 {
        vec.clone()
    } else {
        vec / sum
    }
}

/// Graph structure for graph kernels
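///
/// A construction sketch (same assumed module path):
///
/// ```
/// use scirs2_core::ndarray::{Array1, Array2};
/// use sklears_svm::kernels::Graph;
///
/// // 3-node path graph 0 - 1 - 2 as a symmetric adjacency matrix.
/// let adj = Array2::from_shape_vec(
///     (3, 3),
///     vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
/// )
/// .unwrap();
/// let graph = Graph::new(adj).with_node_labels(Array1::from_vec(vec![0, 1, 0]));
/// assert!(graph.node_labels.is_some());
/// ```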
#[derive(Debug, Clone)]
pub struct Graph {
    pub adjacency_matrix: Array2<f64>,
    pub node_labels: Option<Array1<usize>>,
    pub edge_labels: Option<Array2<usize>>,
}

impl Graph {
    pub fn new(adjacency_matrix: Array2<f64>) -> Self {
        Self {
            adjacency_matrix,
            node_labels: None,
            edge_labels: None,
        }
    }

    pub fn with_node_labels(mut self, labels: Array1<usize>) -> Self {
        self.node_labels = Some(labels);
        self
    }

    pub fn with_edge_labels(mut self, labels: Array2<usize>) -> Self {
        self.edge_labels = Some(labels);
        self
    }
}

/// Random Walk kernel for graphs
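///
/// The full random-walk kernel (Gärtner et al., 2003) counts matching walks via the
/// direct product graph, `K(G1, G2) = sum_k lambda^k * e^T A_x^k e`, where `A_x` is
/// the product-graph adjacency matrix. The implementation below is a stated
/// placeholder that only compares edge densities.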
#[derive(Debug, Clone)]
pub struct RandomWalkKernel {
    pub lambda: f64, // decay parameter
    pub max_steps: usize,
}

impl RandomWalkKernel {
    pub fn new(lambda: f64, max_steps: usize) -> Self {
        Self { lambda, max_steps }
    }

    pub fn compute_graph_kernel(&self, g1: &Graph, g2: &Graph) -> f64 {
        // Simplified random walk kernel computation
        // In practice, this would involve computing the product graph
        // and solving the linear system (I - λA_product)^-1

        let n1 = g1.adjacency_matrix.nrows();
        let n2 = g2.adjacency_matrix.nrows();

        // Placeholder implementation - compute similarity based on graph sizes and densities
        let density1 = g1.adjacency_matrix.sum() / (n1 * n1) as f64;
        let density2 = g2.adjacency_matrix.sum() / (n2 * n2) as f64;

        (-(density1 - density2).abs()).exp()
    }
}

/// Hellinger kernel implementation
#[derive(Debug)]
pub struct HellingerKernel;

impl Kernel for HellingerKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        // Hellinger kernel: K(x,y) = sum(sqrt(x_i * y_i))
        x.iter()
            .zip(y.iter())
            .map(|(xi, yi)| (xi * yi).sqrt())
            .sum()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// Jensen-Shannon kernel implementation
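///
/// `K(x, y) = exp(-JS(x, y))`, where `JS(p, q) = (KL(p || m) + KL(q || m)) / 2`
/// with `m = (p + q) / 2`. Since `JS(p, p) = 0`, identical inputs give `K = 1`.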
#[derive(Debug)]
pub struct JensenShannonKernel;

impl JensenShannonKernel {
    fn jensen_shannon_divergence(&self, p: ArrayView1<f64>, q: ArrayView1<f64>) -> f64 {
        // Jensen-Shannon divergence
        let m: Vec<f64> = p
            .iter()
            .zip(q.iter())
            .map(|(pi, qi)| 0.5 * (pi + qi))
            .collect();
        let m = Array1::from_vec(m);

        let kl_pm = self.kl_divergence(p, m.view());
        let kl_qm = self.kl_divergence(q, m.view());

        0.5 * kl_pm + 0.5 * kl_qm
    }

    fn kl_divergence(&self, p: ArrayView1<f64>, q: ArrayView1<f64>) -> f64 {
        // Kullback-Leibler divergence
        p.iter()
            .zip(q.iter())
            .map(|(pi, qi)| {
                if *pi > 0.0 && *qi > 0.0 {
                    pi * (pi / qi).ln()
                } else {
                    0.0
                }
            })
            .sum()
    }
}

impl Kernel for JensenShannonKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        // Jensen-Shannon kernel: K(x,y) = exp(-JS(x,y))
        let js_div = self.jensen_shannon_divergence(x, y);
        (-js_div).exp()
    }

    fn parameters(&self) -> HashMap<String, f64> {
        HashMap::new()
    }
}

/// Implement Kernel trait for KernelType to enable direct usage
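///
/// A usage sketch (same assumed module path):
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use sklears_svm::kernels::{Kernel, KernelType};
///
/// let kernel = KernelType::Rbf { gamma: 1.0 };
/// let x = Array1::from_vec(vec![1.0, 2.0]);
/// // A KernelType value can be used directly wherever a Kernel is expected.
/// assert_eq!(kernel.compute(x.view(), x.view()), 1.0);
/// ```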
impl Kernel for KernelType {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        match self {
            KernelType::Linear => LinearKernel.compute(x, y),
            KernelType::Rbf { gamma } => RbfKernel::new(*gamma).compute(x, y),
            KernelType::Polynomial {
                gamma,
                coef0,
                degree,
            } => PolynomialKernel::new(*gamma, *coef0, *degree).compute(x, y),
            KernelType::Sigmoid { gamma, coef0 } => {
                SigmoidKernel::new(*gamma, *coef0).compute(x, y)
            }
            KernelType::Cosine => CosineKernel.compute(x, y),
            KernelType::ChiSquared { gamma } => ChiSquaredKernel::new(*gamma).compute(x, y),
            KernelType::Intersection => IntersectionKernel.compute(x, y),
            KernelType::Hellinger => HellingerKernel.compute(x, y),
            KernelType::JensenShannon => JensenShannonKernel.compute(x, y),
            KernelType::Periodic {
                length_scale,
                period,
            } => PeriodicKernel::new(*length_scale, *period).compute(x, y),
            KernelType::Precomputed => {
                // For precomputed kernels, we assume x and y contain indices
                0.0 // Placeholder - actual implementation would look up precomputed values
            }
            KernelType::Custom(_name) => {
                // Default custom implementation
                x.dot(&y)
            }
        }
    }

    fn parameters(&self) -> HashMap<String, f64> {
        match self {
            KernelType::Linear => HashMap::new(),
            KernelType::Rbf { gamma } => {
                let mut params = HashMap::new();
                params.insert("gamma".to_string(), *gamma);
                params
            }
            KernelType::Polynomial {
                gamma,
                coef0,
                degree,
            } => {
                let mut params = HashMap::new();
                params.insert("gamma".to_string(), *gamma);
                params.insert("coef0".to_string(), *coef0);
                params.insert("degree".to_string(), *degree);
                params
            }
            KernelType::Sigmoid { gamma, coef0 } => {
                let mut params = HashMap::new();
                params.insert("gamma".to_string(), *gamma);
                params.insert("coef0".to_string(), *coef0);
                params
            }
            KernelType::Cosine => HashMap::new(),
            KernelType::ChiSquared { gamma } => {
                let mut params = HashMap::new();
                params.insert("gamma".to_string(), *gamma);
                params
            }
            KernelType::Intersection => HashMap::new(),
            KernelType::Hellinger => HashMap::new(),
            KernelType::JensenShannon => HashMap::new(),
            KernelType::Periodic {
                length_scale,
                period,
            } => {
                let mut params = HashMap::new();
                params.insert("length_scale".to_string(), *length_scale);
                params.insert("period".to_string(), *period);
                params
            }
            KernelType::Precomputed => HashMap::new(),
            KernelType::Custom(_name) => HashMap::new(),
        }
    }
}

/// Implement Kernel trait for Box<dyn Kernel> to enable polymorphic usage
impl Kernel for Box<dyn Kernel> {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        self.as_ref().compute(x, y)
    }

    fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
        self.as_ref().compute_matrix(x, y)
    }

    fn parameters(&self) -> HashMap<String, f64> {
        self.as_ref().parameters()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_linear_kernel() {
        let kernel = LinearKernel;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rbf_kernel() {
        let kernel = RbfKernel::new(1.0);
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_polynomial_kernel() {
        let kernel = PolynomialKernel::new(1.0, 1.0, 2.0);
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);

        let result = kernel.compute(x.view(), y.view());
        let expected = (1.0_f64 * (1.0 * 3.0 + 2.0 * 4.0) + 1.0).powf(2.0);
        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_kernel() {
        let kernel = CosineKernel;
        let x = Array1::from_vec(vec![1.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kernel_function() {
        let kernel_fn = KernelFunction::new(KernelType::Rbf { gamma: 0.5 });
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);

        let result = kernel_fn.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kernel_matrix() {
        let kernel_fn = KernelFunction::new(KernelType::Linear);
        let x = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let kernel_matrix = kernel_fn.compute_matrix(&x, &y);

        assert_eq!(kernel_matrix.dim(), (2, 2));
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 5.0, epsilon = 1e-10); // [1,2] · [1,2] = 5
        assert_abs_diff_eq!(kernel_matrix[[1, 1]], 25.0, epsilon = 1e-10); // [3,4] · [3,4] = 25
    }
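
    // Added sanity check (a sketch): for an L1-normalized histogram x,
    // sum(sqrt(x_i * x_i)) = sum(x_i) = 1, so the Hellinger kernel of a
    // distribution with itself is 1.
    #[test]
    fn test_hellinger_kernel_self_similarity() {
        let kernel = HellingerKernel;
        let x = Array1::from_vec(vec![0.2, 0.3, 0.5]);

        let result = kernel.compute(x.view(), x.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }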
}