1use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Distribution;
6use scirs2_core::random::Rng;
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14use scirs2_core::random::{thread_rng, SeedableRng};
15pub trait KernelFunction: Clone + Send + Sync {
17 fn kernel(&self, x: &[Float], y: &[Float]) -> Float;
19
20 fn fourier_transform(&self, w: &[Float]) -> Float;
24
25 fn sample_frequencies(
29 &self,
30 n_features: usize,
31 n_components: usize,
32 rng: &mut RealStdRng,
33 ) -> Array2<Float>;
34
35 fn description(&self) -> String;
37}
38
39#[derive(Debug, Clone)]
41pub struct CustomRBFKernel {
43 pub gamma: Float,
45 pub sigma: Float,
47}
48
49impl CustomRBFKernel {
50 pub fn new(gamma: Float) -> Self {
51 Self {
52 gamma,
53 sigma: (1.0 / (2.0 * gamma)).sqrt(),
54 }
55 }
56
57 pub fn from_sigma(sigma: Float) -> Self {
58 let gamma = 1.0 / (2.0 * sigma * sigma);
59 Self { gamma, sigma }
60 }
61}
62
63impl KernelFunction for CustomRBFKernel {
64 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
65 let dist_sq: Float = x
66 .iter()
67 .zip(y.iter())
68 .map(|(xi, yi)| (xi - yi).powi(2))
69 .sum();
70 (-self.gamma * dist_sq).exp()
71 }
72
73 fn fourier_transform(&self, w: &[Float]) -> Float {
74 let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
75 (-w_norm_sq / (4.0 * self.gamma)).exp()
76 }
77
78 fn sample_frequencies(
79 &self,
80 n_features: usize,
81 n_components: usize,
82 rng: &mut RealStdRng,
83 ) -> Array2<Float> {
84 let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
85 let mut weights = Array2::zeros((n_features, n_components));
86 for mut col in weights.columns_mut() {
87 for val in col.iter_mut() {
88 *val = normal.sample(rng);
89 }
90 }
91 weights
92 }
93
94 fn description(&self) -> String {
95 format!("Custom RBF kernel with gamma={}", self.gamma)
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct CustomPolynomialKernel {
103 pub gamma: Float,
105 pub coef0: Float,
107 pub degree: u32,
109}
110
111impl CustomPolynomialKernel {
112 pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
113 Self {
114 gamma,
115 coef0,
116 degree,
117 }
118 }
119}
120
121impl KernelFunction for CustomPolynomialKernel {
122 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
123 let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
124 (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
125 }
126
127 fn fourier_transform(&self, w: &[Float]) -> Float {
128 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
131 (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
132 }
133
134 fn sample_frequencies(
135 &self,
136 n_features: usize,
137 n_components: usize,
138 rng: &mut RealStdRng,
139 ) -> Array2<Float> {
140 let normal = RandNormal::new(0.0, self.gamma.sqrt()).unwrap();
142 let mut weights = Array2::zeros((n_features, n_components));
143 for mut col in weights.columns_mut() {
144 for val in col.iter_mut() {
145 *val = normal.sample(rng);
146 }
147 }
148 weights
149 }
150
151 fn description(&self) -> String {
152 format!(
153 "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
154 self.degree, self.gamma, self.coef0
155 )
156 }
157}
158
159#[derive(Debug, Clone)]
161pub struct CustomLaplacianKernel {
163 pub gamma: Float,
165}
166
167impl CustomLaplacianKernel {
168 pub fn new(gamma: Float) -> Self {
169 Self { gamma }
170 }
171}
172
173impl KernelFunction for CustomLaplacianKernel {
174 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
175 let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
176 (-self.gamma * l1_dist).exp()
177 }
178
179 fn fourier_transform(&self, w: &[Float]) -> Float {
180 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
181 self.gamma / (self.gamma + w_norm).powi(2)
182 }
183
184 fn sample_frequencies(
185 &self,
186 n_features: usize,
187 n_components: usize,
188 rng: &mut RealStdRng,
189 ) -> Array2<Float> {
190 use scirs2_core::random::Cauchy;
191 let cauchy = Cauchy::new(0.0, self.gamma).unwrap();
192 let mut weights = Array2::zeros((n_features, n_components));
193 for mut col in weights.columns_mut() {
194 for val in col.iter_mut() {
195 *val = cauchy.sample(rng);
196 }
197 }
198 weights
199 }
200
201 fn description(&self) -> String {
202 format!("Custom Laplacian kernel with gamma={}", self.gamma)
203 }
204}
205
206#[derive(Debug, Clone)]
208pub struct CustomExponentialKernel {
210 pub length_scale: Float,
212}
213
214impl CustomExponentialKernel {
215 pub fn new(length_scale: Float) -> Self {
216 Self { length_scale }
217 }
218}
219
220impl KernelFunction for CustomExponentialKernel {
221 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
222 let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
223 (-dist / self.length_scale).exp()
224 }
225
226 fn fourier_transform(&self, w: &[Float]) -> Float {
227 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
228 2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
229 }
230
231 fn sample_frequencies(
232 &self,
233 n_features: usize,
234 n_components: usize,
235 rng: &mut RealStdRng,
236 ) -> Array2<Float> {
237 use scirs2_core::random::Cauchy;
238 let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).unwrap();
239 let mut weights = Array2::zeros((n_features, n_components));
240 for mut col in weights.columns_mut() {
241 for val in col.iter_mut() {
242 *val = cauchy.sample(rng);
243 }
244 }
245 weights
246 }
247
248 fn description(&self) -> String {
249 format!(
250 "Custom Exponential kernel with length_scale={}",
251 self.length_scale
252 )
253 }
254}
255
256#[derive(Debug, Clone)]
284pub struct CustomKernelSampler<K, State = Untrained>
286where
287 K: KernelFunction,
288{
289 pub kernel: K,
291 pub n_components: usize,
293 pub random_state: Option<u64>,
295
296 random_weights_: Option<Array2<Float>>,
298 random_offset_: Option<Array1<Float>>,
299
300 _state: PhantomData<State>,
301}
302
303impl<K> CustomKernelSampler<K, Untrained>
304where
305 K: KernelFunction,
306{
307 pub fn new(kernel: K, n_components: usize) -> Self {
309 Self {
310 kernel,
311 n_components,
312 random_state: None,
313 random_weights_: None,
314 random_offset_: None,
315 _state: PhantomData,
316 }
317 }
318
319 pub fn random_state(mut self, seed: u64) -> Self {
321 self.random_state = Some(seed);
322 self
323 }
324}
325
326impl<K> Estimator for CustomKernelSampler<K, Untrained>
327where
328 K: KernelFunction,
329{
330 type Config = ();
331 type Error = SklearsError;
332 type Float = Float;
333
334 fn config(&self) -> &Self::Config {
335 &()
336 }
337}
338
339impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
340where
341 K: KernelFunction,
342{
343 type Fitted = CustomKernelSampler<K, Trained>;
344
345 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
346 let (_, n_features) = x.dim();
347
348 if self.n_components == 0 {
349 return Err(SklearsError::InvalidInput(
350 "n_components must be positive".to_string(),
351 ));
352 }
353
354 let mut rng = if let Some(seed) = self.random_state {
355 RealStdRng::seed_from_u64(seed)
356 } else {
357 RealStdRng::from_seed(thread_rng().gen())
358 };
359
360 let random_weights =
362 self.kernel
363 .sample_frequencies(n_features, self.n_components, &mut rng);
364
365 let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
367 let mut random_offset = Array1::zeros(self.n_components);
368 for val in random_offset.iter_mut() {
369 *val = rng.sample(uniform);
370 }
371
372 Ok(CustomKernelSampler {
373 kernel: self.kernel,
374 n_components: self.n_components,
375 random_state: self.random_state,
376 random_weights_: Some(random_weights),
377 random_offset_: Some(random_offset),
378 _state: PhantomData,
379 })
380 }
381}
382
383impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
384where
385 K: KernelFunction,
386{
387 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
388 let (_n_samples, n_features) = x.dim();
389 let weights = self.random_weights_.as_ref().unwrap();
390 let offset = self.random_offset_.as_ref().unwrap();
391
392 if n_features != weights.nrows() {
393 return Err(SklearsError::InvalidInput(format!(
394 "X has {} features, but CustomKernelSampler was fitted with {} features",
395 n_features,
396 weights.nrows()
397 )));
398 }
399
400 let projection = x.dot(weights) + offset.view().insert_axis(Axis(0));
402
403 let normalization = (2.0 / self.n_components as Float).sqrt();
405 let result = projection.mapv(|v| normalization * v.cos());
406
407 Ok(result)
408 }
409}
410
411impl<K> CustomKernelSampler<K, Trained>
412where
413 K: KernelFunction,
414{
415 pub fn random_weights(&self) -> &Array2<Float> {
417 self.random_weights_.as_ref().unwrap()
418 }
419
420 pub fn random_offset(&self) -> &Array1<Float> {
422 self.random_offset_.as_ref().unwrap()
423 }
424
425 pub fn kernel_description(&self) -> String {
427 self.kernel.description()
428 }
429
430 pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
432 let (n_x, _) = x.dim();
433 let (n_y, _) = y.dim();
434 let mut kernel_matrix = Array2::zeros((n_x, n_y));
435
436 for i in 0..n_x {
437 for j in 0..n_y {
438 let x_row = x.row(i).to_vec();
439 let y_row = y.row(j).to_vec();
440 kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
441 }
442 }
443
444 kernel_matrix
445 }
446}
447
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// RBF kernel: 1 for identical points, `exp(-gamma * d^2)` otherwise.
    #[test]
    fn test_custom_rbf_kernel() {
        let rbf = CustomRBFKernel::new(0.5);
        let point = [1.0, 2.0];

        assert_abs_diff_eq!(rbf.kernel(&point, &point), 1.0, epsilon = 1e-10);

        // Squared distance to (2, 3) is 1 + 1 = 2.
        let shifted = [2.0, 3.0];
        let want = (-0.5_f64 * 2.0).exp();
        assert_abs_diff_eq!(rbf.kernel(&point, &shifted), want, epsilon = 1e-10);
    }

    /// Polynomial kernel: `(gamma * <x, y> + coef0)^degree`.
    #[test]
    fn test_custom_polynomial_kernel() {
        let poly = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let lhs = [1.0, 2.0];
        let rhs = [2.0, 3.0];

        let inner: f64 = 1.0 * 2.0 + 2.0 * 3.0;
        let want = (1.0 * inner + 1.0).powf(2.0);
        assert_abs_diff_eq!(poly.kernel(&lhs, &rhs), want, epsilon = 1e-10);
    }

    /// Laplacian kernel: 1 for identical points, `exp(-gamma * L1)` otherwise.
    #[test]
    fn test_custom_laplacian_kernel() {
        let laplacian = CustomLaplacianKernel::new(0.5);
        let point = [1.0, 2.0];

        assert_abs_diff_eq!(laplacian.kernel(&point, &point), 1.0, epsilon = 1e-10);

        let other = [2.0, 4.0];
        let l1 = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs();
        let want = (-0.5_f64 * l1).exp();
        assert_abs_diff_eq!(laplacian.kernel(&point, &other), want, epsilon = 1e-10);
    }

    /// Fit + transform yields one column per component and bounded values.
    #[test]
    fn test_custom_kernel_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = CustomKernelSampler::new(CustomRBFKernel::new(0.1), 50)
            .fit(&data, &())
            .unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.shape(), &[3, 50]);
        assert!(features.iter().all(|value| value.abs() <= 2.0));
    }

    /// Two samplers built with the same seed produce the same features.
    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let run = |seed: u64| {
            CustomKernelSampler::new(CustomRBFKernel::new(0.1), 10)
                .random_state(seed)
                .fit(&data, &())
                .unwrap()
                .transform(&data)
                .unwrap()
        };

        let first = run(42);
        let second = run(42);

        for (a, b) in first.iter().zip(second.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// The sampler works with each kernel type and keeps the output shape.
    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        fn assert_two_by_ten<K: KernelFunction>(kernel: K, data: &Array2<Float>) {
            let fitted = CustomKernelSampler::new(kernel, 10)
                .fit(data, &())
                .unwrap();
            let features = fitted.transform(data).unwrap();
            assert_eq!(features.shape(), &[2, 10]);
        }

        let data = array![[1.0, 2.0], [3.0, 4.0]];

        assert_two_by_ten(CustomRBFKernel::new(0.1), &data);
        assert_two_by_ten(CustomPolynomialKernel::new(2, 1.0, 1.0), &data);
        assert_two_by_ten(CustomLaplacianKernel::new(0.5), &data);
    }

    /// The exact kernel matrix matches direct kernel evaluations.
    #[test]
    fn test_exact_kernel_matrix_computation() {
        let left = array![[1.0, 2.0], [3.0, 4.0]];
        let right = array![[1.0, 2.0], [5.0, 6.0]];
        let rbf = CustomRBFKernel::new(0.5);

        let fitted = CustomKernelSampler::new(rbf.clone(), 10)
            .fit(&left, &())
            .unwrap();
        let gram = fitted.exact_kernel_matrix(&left, &right);

        assert_eq!(gram.shape(), &[2, 2]);

        // Row 0 of both inputs is the same point.
        assert_abs_diff_eq!(gram[[0, 0]], 1.0, epsilon = 1e-10);

        let want = rbf.kernel(&[1.0, 2.0], &[5.0, 6.0]);
        assert_abs_diff_eq!(gram[[0, 1]], want, epsilon = 1e-10);
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wider = array![[1.0, 2.0, 3.0]];

        let fitted = CustomKernelSampler::new(CustomRBFKernel::new(0.1), 10)
            .fit(&train, &())
            .unwrap();

        assert!(fitted.transform(&wider).is_err());
    }

    /// Fitting with zero components is rejected.
    #[test]
    fn test_custom_kernel_zero_components() {
        let data = array![[1.0, 2.0]];
        let outcome = CustomKernelSampler::new(CustomRBFKernel::new(0.1), 0).fit(&data, &());
        assert!(outcome.is_err());
    }
}