sklears_kernel_approximation/
simple_test.rs1use crate::nystroem::*;
4use crate::rbf_sampler::*;
5use scirs2_core::ndarray::array;
6use sklears_core::traits::{Fit, Transform};
7
8pub fn test_implementations() {
9 println!("Testing kernel approximation implementations...");
10
11 let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
13
14 println!("Testing RBF Sampler...");
16 match test_rbf_sampler(&x) {
17 Ok(_) => println!("✓ RBF Sampler works"),
18 Err(e) => println!("✗ RBF Sampler failed: {}", e),
19 }
20
21 println!("Testing Laplacian Sampler...");
23 match test_laplacian_sampler(&x) {
24 Ok(_) => println!("✓ Laplacian Sampler works"),
25 Err(e) => println!("✗ Laplacian Sampler failed: {}", e),
26 }
27
28 println!("Testing Polynomial Sampler...");
30 match test_polynomial_sampler(&x) {
31 Ok(_) => println!("✓ Polynomial Sampler works"),
32 Err(e) => println!("✗ Polynomial Sampler failed: {}", e),
33 }
34
35 println!("Testing Arc-cosine Sampler...");
37 match test_arc_cosine_sampler(&x) {
38 Ok(_) => println!("✓ Arc-cosine Sampler works"),
39 Err(e) => println!("✗ Arc-cosine Sampler failed: {}", e),
40 }
41
42 println!("Testing Nyström Method...");
44 match test_nystroem(&x) {
45 Ok(_) => println!("✓ Nyström Method works"),
46 Err(e) => println!("✗ Nyström Method failed: {}", e),
47 }
48
49 println!("Testing complete!");
50}
51
52fn test_rbf_sampler(
53 x: &scirs2_core::ndarray::Array2<f64>,
54) -> Result<(), Box<dyn std::error::Error>> {
55 let rbf = RBFSampler::new(10).gamma(0.1);
56 let fitted = rbf.fit(x, &())?;
57 let result = fitted.transform(x)?;
58 assert_eq!(result.shape(), &[4, 10]);
59 Ok(())
60}
61
62fn test_laplacian_sampler(
63 x: &scirs2_core::ndarray::Array2<f64>,
64) -> Result<(), Box<dyn std::error::Error>> {
65 let laplacian = LaplacianSampler::new(10).gamma(0.1);
66 let fitted = laplacian.fit(x, &())?;
67 let result = fitted.transform(x)?;
68 assert_eq!(result.shape(), &[4, 10]);
69 Ok(())
70}
71
72fn test_polynomial_sampler(
73 x: &scirs2_core::ndarray::Array2<f64>,
74) -> Result<(), Box<dyn std::error::Error>> {
75 let poly = PolynomialSampler::new(10).degree(2).gamma(1.0).coef0(1.0);
76 let fitted = poly.fit(x, &())?;
77 let result = fitted.transform(x)?;
78 assert_eq!(result.shape(), &[4, 10]);
79 Ok(())
80}
81
82fn test_arc_cosine_sampler(
83 x: &scirs2_core::ndarray::Array2<f64>,
84) -> Result<(), Box<dyn std::error::Error>> {
85 let arc_cosine = ArcCosineSampler::new(10).degree(1);
86 let fitted = arc_cosine.fit(x, &())?;
87 let result = fitted.transform(x)?;
88 assert_eq!(result.shape(), &[4, 10]);
89 Ok(())
90}
91
92fn test_nystroem(x: &scirs2_core::ndarray::Array2<f64>) -> Result<(), Box<dyn std::error::Error>> {
93 let nystroem = Nystroem::new(Kernel::Linear, 3).sampling_strategy(SamplingStrategy::Random);
95 let fitted = nystroem.fit(x, &())?;
96 let result = fitted.transform(x)?;
97 assert_eq!(result.nrows(), 4);
98
99 let nystroem_rbf = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 3)
101 .sampling_strategy(SamplingStrategy::LeverageScore);
102 let fitted_rbf = nystroem_rbf.fit(x, &())?;
103 let result_rbf = fitted_rbf.transform(x)?;
104 assert_eq!(result_rbf.nrows(), 4);
105
106 Ok(())
107}