1use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Distribution;
6use scirs2_core::random::RngExt;
7use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Trained, Transform, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
14use scirs2_core::random::{thread_rng, SeedableRng};
/// A kernel that can be approximated with random Fourier features.
///
/// Implementors supply the kernel itself, a frequency-domain counterpart, and a
/// sampler for the projection frequencies that `CustomKernelSampler::fit` stores.
/// Implementations must be `Clone + Send + Sync`.
pub trait KernelFunction: Clone + Send + Sync {
    /// Evaluates the kernel `k(x, y)` for two feature vectors.
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float;

    /// Evaluates a frequency-domain counterpart of the kernel at frequency `w`.
    ///
    /// NOTE(review): the implementations in this file do not share one
    /// normalization convention (for example, the RBF one equals 1 at `w = 0`).
    /// Nothing in the visible code calls this method, so confirm the intended
    /// convention before relying on absolute values.
    fn fourier_transform(&self, w: &[Float]) -> Float;

    /// Draws the random projection frequencies.
    ///
    /// The sampler multiplies input data by the result as `x.dot(weights)` and
    /// checks `weights.nrows()` against the input's column count. The result is
    /// therefore used as an `(n_features, n_components)` matrix.
    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float>;

    /// Returns a human-readable description of the kernel and its parameters.
    fn description(&self) -> String;
}
38
39#[derive(Debug, Clone)]
41pub struct CustomRBFKernel {
43 pub gamma: Float,
45 pub sigma: Float,
47}
48
49impl CustomRBFKernel {
50 pub fn new(gamma: Float) -> Self {
51 Self {
52 gamma,
53 sigma: (1.0 / (2.0 * gamma)).sqrt(),
54 }
55 }
56
57 pub fn from_sigma(sigma: Float) -> Self {
58 let gamma = 1.0 / (2.0 * sigma * sigma);
59 Self { gamma, sigma }
60 }
61}
62
63impl KernelFunction for CustomRBFKernel {
64 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
65 let dist_sq: Float = x
66 .iter()
67 .zip(y.iter())
68 .map(|(xi, yi)| (xi - yi).powi(2))
69 .sum();
70 (-self.gamma * dist_sq).exp()
71 }
72
73 fn fourier_transform(&self, w: &[Float]) -> Float {
74 let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
75 (-w_norm_sq / (4.0 * self.gamma)).exp()
76 }
77
78 fn sample_frequencies(
79 &self,
80 n_features: usize,
81 n_components: usize,
82 rng: &mut RealStdRng,
83 ) -> Array2<Float> {
84 let normal =
85 RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).expect("operation should succeed");
86 let mut weights = Array2::zeros((n_features, n_components));
87 for mut col in weights.columns_mut() {
88 for val in col.iter_mut() {
89 *val = normal.sample(rng);
90 }
91 }
92 weights
93 }
94
95 fn description(&self) -> String {
96 format!("Custom RBF kernel with gamma={}", self.gamma)
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct CustomPolynomialKernel {
104 pub gamma: Float,
106 pub coef0: Float,
108 pub degree: u32,
110}
111
112impl CustomPolynomialKernel {
113 pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
114 Self {
115 gamma,
116 coef0,
117 degree,
118 }
119 }
120}
121
122impl KernelFunction for CustomPolynomialKernel {
123 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
124 let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
125 (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
126 }
127
128 fn fourier_transform(&self, w: &[Float]) -> Float {
129 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
132 (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
133 }
134
135 fn sample_frequencies(
136 &self,
137 n_features: usize,
138 n_components: usize,
139 rng: &mut RealStdRng,
140 ) -> Array2<Float> {
141 let normal = RandNormal::new(0.0, self.gamma.sqrt()).expect("operation should succeed");
143 let mut weights = Array2::zeros((n_features, n_components));
144 for mut col in weights.columns_mut() {
145 for val in col.iter_mut() {
146 *val = normal.sample(rng);
147 }
148 }
149 weights
150 }
151
152 fn description(&self) -> String {
153 format!(
154 "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
155 self.degree, self.gamma, self.coef0
156 )
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct CustomLaplacianKernel {
164 pub gamma: Float,
166}
167
168impl CustomLaplacianKernel {
169 pub fn new(gamma: Float) -> Self {
170 Self { gamma }
171 }
172}
173
174impl KernelFunction for CustomLaplacianKernel {
175 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
176 let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
177 (-self.gamma * l1_dist).exp()
178 }
179
180 fn fourier_transform(&self, w: &[Float]) -> Float {
181 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
182 self.gamma / (self.gamma + w_norm).powi(2)
183 }
184
185 fn sample_frequencies(
186 &self,
187 n_features: usize,
188 n_components: usize,
189 rng: &mut RealStdRng,
190 ) -> Array2<Float> {
191 use scirs2_core::random::Cauchy;
192 let cauchy = Cauchy::new(0.0, self.gamma).expect("operation should succeed");
193 let mut weights = Array2::zeros((n_features, n_components));
194 for mut col in weights.columns_mut() {
195 for val in col.iter_mut() {
196 *val = cauchy.sample(rng);
197 }
198 }
199 weights
200 }
201
202 fn description(&self) -> String {
203 format!("Custom Laplacian kernel with gamma={}", self.gamma)
204 }
205}
206
207#[derive(Debug, Clone)]
209pub struct CustomExponentialKernel {
211 pub length_scale: Float,
213}
214
215impl CustomExponentialKernel {
216 pub fn new(length_scale: Float) -> Self {
217 Self { length_scale }
218 }
219}
220
221impl KernelFunction for CustomExponentialKernel {
222 fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
223 let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
224 (-dist / self.length_scale).exp()
225 }
226
227 fn fourier_transform(&self, w: &[Float]) -> Float {
228 let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
229 2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
230 }
231
232 fn sample_frequencies(
233 &self,
234 n_features: usize,
235 n_components: usize,
236 rng: &mut RealStdRng,
237 ) -> Array2<Float> {
238 use scirs2_core::random::Cauchy;
239 let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).expect("operation should succeed");
240 let mut weights = Array2::zeros((n_features, n_components));
241 for mut col in weights.columns_mut() {
242 for val in col.iter_mut() {
243 *val = cauchy.sample(rng);
244 }
245 }
246 weights
247 }
248
249 fn description(&self) -> String {
250 format!(
251 "Custom Exponential kernel with length_scale={}",
252 self.length_scale
253 )
254 }
255}
256
257#[derive(Debug, Clone)]
285pub struct CustomKernelSampler<K, State = Untrained>
287where
288 K: KernelFunction,
289{
290 pub kernel: K,
292 pub n_components: usize,
294 pub random_state: Option<u64>,
296
297 random_weights_: Option<Array2<Float>>,
299 random_offset_: Option<Array1<Float>>,
300
301 _state: PhantomData<State>,
302}
303
304impl<K> CustomKernelSampler<K, Untrained>
305where
306 K: KernelFunction,
307{
308 pub fn new(kernel: K, n_components: usize) -> Self {
310 Self {
311 kernel,
312 n_components,
313 random_state: None,
314 random_weights_: None,
315 random_offset_: None,
316 _state: PhantomData,
317 }
318 }
319
320 pub fn random_state(mut self, seed: u64) -> Self {
322 self.random_state = Some(seed);
323 self
324 }
325}
326
327impl<K> Estimator for CustomKernelSampler<K, Untrained>
328where
329 K: KernelFunction,
330{
331 type Config = ();
332 type Error = SklearsError;
333 type Float = Float;
334
335 fn config(&self) -> &Self::Config {
336 &()
337 }
338}
339
340impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
341where
342 K: KernelFunction,
343{
344 type Fitted = CustomKernelSampler<K, Trained>;
345
346 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
347 let (_, n_features) = x.dim();
348
349 if self.n_components == 0 {
350 return Err(SklearsError::InvalidInput(
351 "n_components must be positive".to_string(),
352 ));
353 }
354
355 let mut rng = if let Some(seed) = self.random_state {
356 RealStdRng::seed_from_u64(seed)
357 } else {
358 RealStdRng::from_seed(thread_rng().random())
359 };
360
361 let random_weights =
363 self.kernel
364 .sample_frequencies(n_features, self.n_components, &mut rng);
365
366 let uniform =
368 RandUniform::new(0.0, 2.0 * std::f64::consts::PI).expect("operation should succeed");
369 let mut random_offset = Array1::zeros(self.n_components);
370 for val in random_offset.iter_mut() {
371 *val = rng.sample(uniform);
372 }
373
374 Ok(CustomKernelSampler {
375 kernel: self.kernel,
376 n_components: self.n_components,
377 random_state: self.random_state,
378 random_weights_: Some(random_weights),
379 random_offset_: Some(random_offset),
380 _state: PhantomData,
381 })
382 }
383}
384
385impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
386where
387 K: KernelFunction,
388{
389 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
390 let (_n_samples, n_features) = x.dim();
391 let weights = self
392 .random_weights_
393 .as_ref()
394 .expect("operation should succeed");
395 let offset = self
396 .random_offset_
397 .as_ref()
398 .expect("operation should succeed");
399
400 if n_features != weights.nrows() {
401 return Err(SklearsError::InvalidInput(format!(
402 "X has {} features, but CustomKernelSampler was fitted with {} features",
403 n_features,
404 weights.nrows()
405 )));
406 }
407
408 let projection = x.dot(weights) + offset.view().insert_axis(Axis(0));
410
411 let normalization = (2.0 / self.n_components as Float).sqrt();
413 let result = projection.mapv(|v| normalization * v.cos());
414
415 Ok(result)
416 }
417}
418
419impl<K> CustomKernelSampler<K, Trained>
420where
421 K: KernelFunction,
422{
423 pub fn random_weights(&self) -> &Array2<Float> {
425 self.random_weights_
426 .as_ref()
427 .expect("operation should succeed")
428 }
429
430 pub fn random_offset(&self) -> &Array1<Float> {
432 self.random_offset_
433 .as_ref()
434 .expect("operation should succeed")
435 }
436
437 pub fn kernel_description(&self) -> String {
439 self.kernel.description()
440 }
441
442 pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
444 let (n_x, _) = x.dim();
445 let (n_y, _) = y.dim();
446 let mut kernel_matrix = Array2::zeros((n_x, n_y));
447
448 for i in 0..n_x {
449 for j in 0..n_y {
450 let x_row = x.row(i).to_vec();
451 let y_row = y.row(j).to_vec();
452 kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
453 }
454 }
455
456 kernel_matrix
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_custom_rbf_kernel() {
        let kernel = CustomRBFKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        // Identical points have kernel value 1.
        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        // ||(1, 2) - (2, 3)||^2 = 2.
        let y2 = vec![2.0, 3.0];
        let expected = (-0.5_f64 * 2.0).exp();
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_polynomial_kernel() {
        let kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 3.0];

        let dot_product = 1.0 * 2.0 + 2.0 * 3.0;
        let expected = (1.0_f64 * dot_product + 1.0).powf(2.0);
        assert_abs_diff_eq!(kernel.kernel(&x, &y), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_laplacian_kernel() {
        let kernel = CustomLaplacianKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 4.0];
        let l1_dist = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs();
        let expected = (-0.5_f64 * l1_dist).exp();
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let sampler = CustomKernelSampler::new(kernel, 50);
        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.shape(), &[3, 50]);

        // Each feature is sqrt(2 / n_components) * cos(..), so its magnitude is
        // bounded by that normalization factor.
        let bound = (2.0 / 50.0_f64).sqrt();
        for val in x_transformed.iter() {
            assert!(val.abs() <= bound, "{} exceeds {}", val, bound);
        }
    }

    #[test]
    fn test_custom_rbf_sampler_approximates_exact_kernel() {
        let x = array![[0.0, 0.0], [0.5, -0.5], [1.0, 1.0]];
        let sampler = CustomKernelSampler::new(CustomRBFKernel::new(0.5), 5000).random_state(7);
        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        // With 5000 bounded features, the Monte Carlo error of each inner product
        // has standard deviation of roughly 0.014, so 0.1 is a generous tolerance.
        let approx = features.dot(&features.t());
        let exact = fitted.exact_kernel_matrix(&x, &x);
        for (a, e) in approx.iter().zip(exact.iter()) {
            assert!((a - e).abs() < 0.1, "approximation {} vs exact {}", a, e);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let kernel1 = CustomRBFKernel::new(0.1);
        let kernel2 = CustomRBFKernel::new(0.1);

        let sampler1 = CustomKernelSampler::new(kernel1, 10).random_state(42);
        let fitted1 = sampler1.fit(&x, &()).expect("operation should succeed");
        let result1 = fitted1.transform(&x).expect("operation should succeed");

        let sampler2 = CustomKernelSampler::new(kernel2, 10).random_state(42);
        let fitted2 = sampler2.fit(&x, &()).expect("operation should succeed");
        let result2 = fitted2.transform(&x).expect("operation should succeed");

        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let rbf_kernel = CustomRBFKernel::new(0.1);
        let rbf_sampler = CustomKernelSampler::new(rbf_kernel, 10);
        let fitted_rbf = rbf_sampler.fit(&x, &()).expect("operation should succeed");
        let result_rbf = fitted_rbf.transform(&x).expect("operation should succeed");
        assert_eq!(result_rbf.shape(), &[2, 10]);

        let poly_kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let poly_sampler = CustomKernelSampler::new(poly_kernel, 10);
        let fitted_poly = poly_sampler.fit(&x, &()).expect("operation should succeed");
        let result_poly = fitted_poly.transform(&x).expect("operation should succeed");
        assert_eq!(result_poly.shape(), &[2, 10]);

        let lap_kernel = CustomLaplacianKernel::new(0.5);
        let lap_sampler = CustomKernelSampler::new(lap_kernel, 10);
        let fitted_lap = lap_sampler.fit(&x, &()).expect("operation should succeed");
        let result_lap = fitted_lap.transform(&x).expect("operation should succeed");
        assert_eq!(result_lap.shape(), &[2, 10]);
    }

    #[test]
    fn test_exact_kernel_matrix_computation() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.5);

        let sampler = CustomKernelSampler::new(kernel.clone(), 10);
        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        let kernel_matrix = fitted.exact_kernel_matrix(&x, &y);

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // x[0] == y[0], so the kernel value is 1.
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 1.0, epsilon = 1e-10);

        let x1 = vec![1.0, 2.0];
        let y2 = vec![5.0, 6.0];
        let expected = kernel.kernel(&x1, &y2);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        // Three columns where the sampler was fitted on two.
        let x_test = array![[1.0, 2.0, 3.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 10);
        let fitted = sampler
            .fit(&x_train, &())
            .expect("operation should succeed");
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_custom_kernel_zero_components() {
        let x = array![[1.0, 2.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 0);
        let result = sampler.fit(&x, &());
        assert!(result.is_err());
    }
}