use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Distribution;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

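/// A user-supplied kernel that the random Fourier feature sampler can approximate.
///
/// Implementors provide a pointwise kernel evaluation, the kernel's Fourier
/// transform (its spectral density, up to normalization), and a sampler that
/// draws random frequencies from that spectral distribution.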
pub trait KernelFunction: Clone + Send + Sync {
    /// Evaluate the kernel k(x, y) for two feature vectors of equal length.
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float;

    /// Evaluate the kernel's Fourier transform at the frequency vector `w`.
    fn fourier_transform(&self, w: &[Float]) -> Float;

    /// Sample an `(n_features, n_components)` matrix of random frequencies
    /// from the kernel's spectral distribution.
    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float>;

    /// Human-readable description of the kernel and its parameters.
    fn description(&self) -> String;
}

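/// Gaussian (RBF) kernel `k(x, y) = exp(-gamma * ||x - y||^2)`.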
#[derive(Debug, Clone)]
pub struct CustomRBFKernel {
    /// Bandwidth parameter; larger values give a narrower kernel.
    pub gamma: Float,
    /// Equivalent length scale, `sigma = sqrt(1 / (2 * gamma))`.
    pub sigma: Float,
}

impl CustomRBFKernel {
    pub fn new(gamma: Float) -> Self {
        Self {
            gamma,
            sigma: (1.0 / (2.0 * gamma)).sqrt(),
        }
    }

    pub fn from_sigma(sigma: Float) -> Self {
        let gamma = 1.0 / (2.0 * sigma * sigma);
        Self { gamma, sigma }
    }
}

impl KernelFunction for CustomRBFKernel {
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
        let dist_sq: Float = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| (xi - yi).powi(2))
            .sum();
        (-self.gamma * dist_sq).exp()
    }

    fn fourier_transform(&self, w: &[Float]) -> Float {
        let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
        (-w_norm_sq / (4.0 * self.gamma)).exp()
    }

    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float> {
        // The spectral density of exp(-gamma * ||d||^2) is Gaussian with
        // standard deviation sqrt(2 * gamma) in each coordinate.
        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
        let mut weights = Array2::zeros((n_features, n_components));
        for mut col in weights.columns_mut() {
            for val in col.iter_mut() {
                *val = normal.sample(rng);
            }
        }
        weights
    }

    fn description(&self) -> String {
        format!("Custom RBF kernel with gamma={}", self.gamma)
    }
}

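/// Polynomial kernel `k(x, y) = (gamma * <x, y> + coef0)^degree`.
///
/// Polynomial kernels are not shift-invariant, so the random Fourier feature
/// mapping used here is only a heuristic approximation.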
#[derive(Debug, Clone)]
pub struct CustomPolynomialKernel {
    /// Scale applied to the inner product.
    pub gamma: Float,
    /// Additive constant inside the power.
    pub coef0: Float,
    /// Polynomial degree.
    pub degree: u32,
}

impl CustomPolynomialKernel {
    pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
        Self {
            gamma,
            coef0,
            degree,
        }
    }
}

impl KernelFunction for CustomPolynomialKernel {
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
        let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
        (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
    }

    fn fourier_transform(&self, w: &[Float]) -> Float {
        // Heuristic surrogate: a rational decay in the L1 norm of the frequency.
        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
        (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
    }

    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float> {
        // Heuristic: draw Gaussian frequencies with variance gamma.
        let normal = RandNormal::new(0.0, self.gamma.sqrt()).unwrap();
        let mut weights = Array2::zeros((n_features, n_components));
        for mut col in weights.columns_mut() {
            for val in col.iter_mut() {
                *val = normal.sample(rng);
            }
        }
        weights
    }

    fn description(&self) -> String {
        format!(
            "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
            self.degree, self.gamma, self.coef0
        )
    }
}

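/// Laplacian kernel `k(x, y) = exp(-gamma * ||x - y||_1)`.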
#[derive(Debug, Clone)]
pub struct CustomLaplacianKernel {
    /// Bandwidth parameter for the L1 distance.
    pub gamma: Float,
}

impl CustomLaplacianKernel {
    pub fn new(gamma: Float) -> Self {
        Self { gamma }
    }
}

impl KernelFunction for CustomLaplacianKernel {
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
        let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
        (-self.gamma * l1_dist).exp()
    }

    fn fourier_transform(&self, w: &[Float]) -> Float {
        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
        self.gamma / (self.gamma + w_norm).powi(2)
    }

    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float> {
        use scirs2_core::random::Cauchy;
        // The spectral distribution of the Laplacian kernel is Cauchy with scale gamma.
        let cauchy = Cauchy::new(0.0, self.gamma).unwrap();
        let mut weights = Array2::zeros((n_features, n_components));
        for mut col in weights.columns_mut() {
            for val in col.iter_mut() {
                *val = cauchy.sample(rng);
            }
        }
        weights
    }

    fn description(&self) -> String {
        format!("Custom Laplacian kernel with gamma={}", self.gamma)
    }
}

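/// Exponential kernel `k(x, y) = exp(-||x - y||_1 / length_scale)`.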
#[derive(Debug, Clone)]
pub struct CustomExponentialKernel {
    /// Length scale; larger values give a wider kernel.
    pub length_scale: Float,
}

impl CustomExponentialKernel {
    pub fn new(length_scale: Float) -> Self {
        Self { length_scale }
    }
}

impl KernelFunction for CustomExponentialKernel {
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
        let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
        (-dist / self.length_scale).exp()
    }

    fn fourier_transform(&self, w: &[Float]) -> Float {
        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
        2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
    }

    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float> {
        use scirs2_core::random::Cauchy;
        // Cauchy frequencies with scale 1 / length_scale match the exponential decay.
        let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).unwrap();
        let mut weights = Array2::zeros((n_features, n_components));
        for mut col in weights.columns_mut() {
            for val in col.iter_mut() {
                *val = cauchy.sample(rng);
            }
        }
        weights
    }

    fn description(&self) -> String {
        format!(
            "Custom Exponential kernel with length_scale={}",
            self.length_scale
        )
    }
}

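/// Random Fourier feature sampler for a user-defined [`KernelFunction`].
///
/// Fitting draws a random frequency matrix from the kernel's spectral
/// distribution and uniform phase offsets; transforming maps inputs to
/// `sqrt(2 / n_components) * cos(x W + b)`, whose inner products approximate
/// the kernel.
///
/// A minimal usage sketch, mirroring the tests below (the `Fit` and
/// `Transform` traits from `sklears_core` must be in scope):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [3.0, 4.0]];
/// let kernel = CustomRBFKernel::new(0.1);
/// let sampler = CustomKernelSampler::new(kernel, 50).random_state(42);
/// let fitted = sampler.fit(&x, &()).unwrap();
/// let features = fitted.transform(&x).unwrap(); // shape (2, 50)
/// ```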
#[derive(Debug, Clone)]
pub struct CustomKernelSampler<K, State = Untrained>
where
    K: KernelFunction,
{
    /// Kernel whose feature map is being approximated.
    pub kernel: K,
    /// Number of random Fourier features (output dimensionality).
    pub n_components: usize,
    /// Optional seed for reproducible sampling.
    pub random_state: Option<u64>,

    /// Random frequency matrix, set after fitting.
    random_weights_: Option<Array2<Float>>,
    /// Random phase offsets, set after fitting.
    random_offset_: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl<K> CustomKernelSampler<K, Untrained>
where
    K: KernelFunction,
{
    /// Create a new sampler for `kernel` with `n_components` random features.
    pub fn new(kernel: K, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            random_state: None,
            random_weights_: None,
            random_offset_: None,
            _state: PhantomData,
        }
    }

    /// Set the random seed for reproducible feature maps.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl<K> Estimator for CustomKernelSampler<K, Untrained>
where
    K: KernelFunction,
{
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

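// Fitting draws the random frequencies from the kernel's spectral distribution
// and uniform phase offsets in [0, 2*pi).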
impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
where
    K: KernelFunction,
{
    type Fitted = CustomKernelSampler<K, Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        if self.n_components == 0 {
            return Err(SklearsError::InvalidInput(
                "n_components must be positive".to_string(),
            ));
        }

        let mut rng = if let Some(seed) = self.random_state {
            RealStdRng::seed_from_u64(seed)
        } else {
            RealStdRng::from_seed(thread_rng().gen())
        };

        // Frequencies W: one column per random feature.
        let random_weights =
            self.kernel
                .sample_frequencies(n_features, self.n_components, &mut rng);

        // Phase offsets b ~ Uniform[0, 2*pi).
        let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
        let mut random_offset = Array1::zeros(self.n_components);
        for val in random_offset.iter_mut() {
            *val = rng.sample(uniform);
        }

        Ok(CustomKernelSampler {
            kernel: self.kernel,
            n_components: self.n_components,
            random_state: self.random_state,
            random_weights_: Some(random_weights),
            random_offset_: Some(random_offset),
            _state: PhantomData,
        })
    }
}

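// The approximate feature map: z(x) = sqrt(2 / n_components) * cos(x W + b),
// so that <z(x), z(y)> approximates k(x, y).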
impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
where
    K: KernelFunction,
{
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let (_, n_features) = x.dim();
        let weights = self.random_weights_.as_ref().unwrap();
        let offset = self.random_offset_.as_ref().unwrap();

        if n_features != weights.nrows() {
            return Err(SklearsError::InvalidInput(format!(
                "X has {} features, but CustomKernelSampler was fitted with {} features",
                n_features,
                weights.nrows()
            )));
        }

        // Project onto the random frequencies and add the phase offsets.
        let projection = x.dot(weights) + &offset.view().insert_axis(Axis(0));

        // Scale so that inner products of the features approximate the kernel.
        let normalization = (2.0 / self.n_components as Float).sqrt();
        let result = projection.mapv(|v| normalization * v.cos());

        Ok(result)
    }
}

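// Accessors for the fitted state, plus an exact kernel evaluator that is handy
// for checking the quality of the random feature approximation.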
impl<K> CustomKernelSampler<K, Trained>
where
    K: KernelFunction,
{
    /// Fitted random frequency matrix of shape `(n_features, n_components)`.
    pub fn random_weights(&self) -> &Array2<Float> {
        self.random_weights_.as_ref().unwrap()
    }

    /// Fitted random phase offsets of length `n_components`.
    pub fn random_offset(&self) -> &Array1<Float> {
        self.random_offset_.as_ref().unwrap()
    }

    /// Description of the underlying kernel.
    pub fn kernel_description(&self) -> String {
        self.kernel.description()
    }

    /// Compute the exact kernel matrix between the rows of `x` and `y`.
    pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
        let (n_x, _) = x.dim();
        let (n_y, _) = y.dim();
        let mut kernel_matrix = Array2::zeros((n_x, n_y));

        for i in 0..n_x {
            for j in 0..n_y {
                let x_row = x.row(i).to_vec();
                let y_row = y.row(j).to_vec();
                kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
            }
        }

        kernel_matrix
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_custom_rbf_kernel() {
        let kernel = CustomRBFKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 3.0];
        let expected = (-0.5_f64 * 2.0).exp();
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_polynomial_kernel() {
        let kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 3.0];

        let dot_product = 1.0 * 2.0 + 2.0 * 3.0;
        let expected = (1.0_f64 * dot_product + 1.0).powf(2.0);
        assert_abs_diff_eq!(kernel.kernel(&x, &y), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_laplacian_kernel() {
        let kernel = CustomLaplacianKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 4.0];
        let l1_dist = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs();
        let expected = (-0.5_f64 * l1_dist).exp();
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let sampler = CustomKernelSampler::new(kernel, 50);
        let fitted = sampler.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[3, 50]);

        for val in x_transformed.iter() {
            // Scaled cosine features stay within a small bounded range.
            assert!(val.abs() <= 2.0);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let kernel1 = CustomRBFKernel::new(0.1);
        let kernel2 = CustomRBFKernel::new(0.1);

        let sampler1 = CustomKernelSampler::new(kernel1, 10).random_state(42);
        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let sampler2 = CustomKernelSampler::new(kernel2, 10).random_state(42);
        let fitted2 = sampler2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        for (a, b) in result1.iter().zip(result2.iter()) {
            // The same seed must produce identical feature maps.
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let rbf_kernel = CustomRBFKernel::new(0.1);
        let rbf_sampler = CustomKernelSampler::new(rbf_kernel, 10);
        let fitted_rbf = rbf_sampler.fit(&x, &()).unwrap();
        let result_rbf = fitted_rbf.transform(&x).unwrap();
        assert_eq!(result_rbf.shape(), &[2, 10]);

        let poly_kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let poly_sampler = CustomKernelSampler::new(poly_kernel, 10);
        let fitted_poly = poly_sampler.fit(&x, &()).unwrap();
        let result_poly = fitted_poly.transform(&x).unwrap();
        assert_eq!(result_poly.shape(), &[2, 10]);

        let lap_kernel = CustomLaplacianKernel::new(0.5);
        let lap_sampler = CustomKernelSampler::new(lap_kernel, 10);
        let fitted_lap = lap_sampler.fit(&x, &()).unwrap();
        let result_lap = fitted_lap.transform(&x).unwrap();
        assert_eq!(result_lap.shape(), &[2, 10]);
    }

    #[test]
    fn test_exact_kernel_matrix_computation() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.5);

        let sampler = CustomKernelSampler::new(kernel.clone(), 10);
        let fitted = sampler.fit(&x, &()).unwrap();
        let kernel_matrix = fitted.exact_kernel_matrix(&x, &y);

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // x[0] equals y[0], so the exact kernel value is 1.
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 1.0, epsilon = 1e-10);

        let x1 = vec![1.0, 2.0];
        let y2 = vec![5.0, 6.0];
        let expected = kernel.kernel(&x1, &y2);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // three features instead of the two used for fitting

        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 10);
        let fitted = sampler.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_custom_kernel_zero_components() {
        let x = array![[1.0, 2.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 0);
        let result = sampler.fit(&x, &());
        assert!(result.is_err());
    }
}