sklears_kernel_approximation/
custom_kernel.rs

1//! Custom kernel random feature generation framework
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Distribution;
6use sklears_core::{
7    error::{Result, SklearsError},
8    traits::{Estimator, Fit, Trained, Transform, Untrained},
9    types::Float,
10};
11use std::marker::PhantomData;
12
13use scirs2_core::random::{thread_rng, Rng, SeedableRng};
/// Trait for defining custom kernel functions.
///
/// An implementor provides the exact kernel evaluation, a Fourier-domain
/// description of the kernel, and a sampler for the random projection matrix
/// that `CustomKernelSampler` uses when it is fitted.
pub trait KernelFunction: Clone + Send + Sync {
    /// Compute the kernel value between two vectors
    fn kernel(&self, x: &[Float], y: &[Float]) -> Float;

    /// Get the characteristic function of the kernel's Fourier transform
    /// This should return the Fourier transform of the kernel at frequency w
    /// For most kernels, this is needed to generate appropriate random features
    fn fourier_transform(&self, w: &[Float]) -> Float;

    /// Sample random frequencies for Random Fourier Features
    /// This method should sample frequencies according to the spectral measure
    /// of the kernel (i.e., the Fourier transform of the kernel)
    ///
    /// `CustomKernelSampler` multiplies its input by the returned matrix and
    /// compares the matrix's row count with the input's feature count, so the
    /// result must have shape `(n_features, n_components)`.
    fn sample_frequencies(
        &self,
        n_features: usize,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Array2<Float>;

    /// Get a description of the kernel
    fn description(&self) -> String;
}
37
38/// Custom RBF kernel with configurable parameters
39#[derive(Debug, Clone)]
40/// CustomRBFKernel
41pub struct CustomRBFKernel {
42    /// gamma
43    pub gamma: Float,
44    /// sigma
45    pub sigma: Float,
46}
47
48impl CustomRBFKernel {
49    pub fn new(gamma: Float) -> Self {
50        Self {
51            gamma,
52            sigma: (1.0 / (2.0 * gamma)).sqrt(),
53        }
54    }
55
56    pub fn from_sigma(sigma: Float) -> Self {
57        let gamma = 1.0 / (2.0 * sigma * sigma);
58        Self { gamma, sigma }
59    }
60}
61
62impl KernelFunction for CustomRBFKernel {
63    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
64        let dist_sq: Float = x
65            .iter()
66            .zip(y.iter())
67            .map(|(xi, yi)| (xi - yi).powi(2))
68            .sum();
69        (-self.gamma * dist_sq).exp()
70    }
71
72    fn fourier_transform(&self, w: &[Float]) -> Float {
73        let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
74        (-w_norm_sq / (4.0 * self.gamma)).exp()
75    }
76
77    fn sample_frequencies(
78        &self,
79        n_features: usize,
80        n_components: usize,
81        rng: &mut RealStdRng,
82    ) -> Array2<Float> {
83        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
84        let mut weights = Array2::zeros((n_features, n_components));
85        for mut col in weights.columns_mut() {
86            for val in col.iter_mut() {
87                *val = normal.sample(rng);
88            }
89        }
90        weights
91    }
92
93    fn description(&self) -> String {
94        format!("Custom RBF kernel with gamma={}", self.gamma)
95    }
96}
97
98/// Custom polynomial kernel
99#[derive(Debug, Clone)]
100/// CustomPolynomialKernel
101pub struct CustomPolynomialKernel {
102    /// gamma
103    pub gamma: Float,
104    /// coef0
105    pub coef0: Float,
106    /// degree
107    pub degree: u32,
108}
109
110impl CustomPolynomialKernel {
111    pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
112        Self {
113            gamma,
114            coef0,
115            degree,
116        }
117    }
118}
119
120impl KernelFunction for CustomPolynomialKernel {
121    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
122        let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
123        (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
124    }
125
126    fn fourier_transform(&self, w: &[Float]) -> Float {
127        // Polynomial kernels don't have a simple Fourier transform
128        // We use an approximation based on the dominant frequency component
129        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
130        (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
131    }
132
133    fn sample_frequencies(
134        &self,
135        n_features: usize,
136        n_components: usize,
137        rng: &mut RealStdRng,
138    ) -> Array2<Float> {
139        // For polynomial kernels, we sample from a scaled normal distribution
140        let normal = RandNormal::new(0.0, self.gamma.sqrt()).unwrap();
141        let mut weights = Array2::zeros((n_features, n_components));
142        for mut col in weights.columns_mut() {
143            for val in col.iter_mut() {
144                *val = normal.sample(rng);
145            }
146        }
147        weights
148    }
149
150    fn description(&self) -> String {
151        format!(
152            "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
153            self.degree, self.gamma, self.coef0
154        )
155    }
156}
157
158/// Custom Laplacian kernel
159#[derive(Debug, Clone)]
160/// CustomLaplacianKernel
161pub struct CustomLaplacianKernel {
162    /// gamma
163    pub gamma: Float,
164}
165
166impl CustomLaplacianKernel {
167    pub fn new(gamma: Float) -> Self {
168        Self { gamma }
169    }
170}
171
172impl KernelFunction for CustomLaplacianKernel {
173    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
174        let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
175        (-self.gamma * l1_dist).exp()
176    }
177
178    fn fourier_transform(&self, w: &[Float]) -> Float {
179        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
180        self.gamma / (self.gamma + w_norm).powi(2)
181    }
182
183    fn sample_frequencies(
184        &self,
185        n_features: usize,
186        n_components: usize,
187        rng: &mut RealStdRng,
188    ) -> Array2<Float> {
189        use scirs2_core::random::Cauchy;
190        let cauchy = Cauchy::new(0.0, self.gamma).unwrap();
191        let mut weights = Array2::zeros((n_features, n_components));
192        for mut col in weights.columns_mut() {
193            for val in col.iter_mut() {
194                *val = cauchy.sample(rng);
195            }
196        }
197        weights
198    }
199
200    fn description(&self) -> String {
201        format!("Custom Laplacian kernel with gamma={}", self.gamma)
202    }
203}
204
205/// Custom exponential kernel (1D version of Laplacian)
206#[derive(Debug, Clone)]
207/// CustomExponentialKernel
208pub struct CustomExponentialKernel {
209    /// length_scale
210    pub length_scale: Float,
211}
212
213impl CustomExponentialKernel {
214    pub fn new(length_scale: Float) -> Self {
215        Self { length_scale }
216    }
217}
218
219impl KernelFunction for CustomExponentialKernel {
220    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
221        let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
222        (-dist / self.length_scale).exp()
223    }
224
225    fn fourier_transform(&self, w: &[Float]) -> Float {
226        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
227        2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
228    }
229
230    fn sample_frequencies(
231        &self,
232        n_features: usize,
233        n_components: usize,
234        rng: &mut RealStdRng,
235    ) -> Array2<Float> {
236        use scirs2_core::random::Cauchy;
237        let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).unwrap();
238        let mut weights = Array2::zeros((n_features, n_components));
239        for mut col in weights.columns_mut() {
240            for val in col.iter_mut() {
241                *val = cauchy.sample(rng);
242            }
243        }
244        weights
245    }
246
247    fn description(&self) -> String {
248        format!(
249            "Custom Exponential kernel with length_scale={}",
250            self.length_scale
251        )
252    }
253}
254
/// Custom kernel random feature generator
///
/// Generates random Fourier features for any custom kernel function that implements
/// the KernelFunction trait. This provides a flexible framework for kernel approximation
/// with user-defined kernels.
///
/// # Parameters
///
/// * `kernel` - Custom kernel function implementing KernelFunction trait
/// * `n_components` - Number of random features to generate (required; fitting
///   rejects 0)
/// * `random_state` - Optional random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::custom_kernel::{CustomKernelSampler, CustomRBFKernel};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let kernel = CustomRBFKernel::new(0.1);
///
/// let sampler = CustomKernelSampler::new(kernel, 100);
/// let fitted_sampler = sampler.fit(&X, &()).unwrap();
/// let X_transformed = fitted_sampler.transform(&X).unwrap();
/// assert_eq!(X_transformed.shape(), &[3, 100]);
/// ```
#[derive(Debug, Clone)]
pub struct CustomKernelSampler<K, State = Untrained>
where
    K: KernelFunction,
{
    /// Custom kernel function
    pub kernel: K,
    /// Number of random features
    pub n_components: usize,
    /// Random seed; when `None`, fitting seeds its generator from `thread_rng`
    pub random_state: Option<u64>,

    // Fitted attributes: both are `None` until `fit` fills them in.
    // Projection matrix returned by the kernel's `sample_frequencies`.
    random_weights_: Option<Array2<Float>>,
    // One phase offset per component, drawn from Uniform(0, 2π) during `fit`.
    random_offset_: Option<Array1<Float>>,

    // Zero-sized marker recording whether this value is `Untrained` or `Trained`.
    _state: PhantomData<State>,
}
301
302impl<K> CustomKernelSampler<K, Untrained>
303where
304    K: KernelFunction,
305{
306    /// Create a new custom kernel sampler
307    pub fn new(kernel: K, n_components: usize) -> Self {
308        Self {
309            kernel,
310            n_components,
311            random_state: None,
312            random_weights_: None,
313            random_offset_: None,
314            _state: PhantomData,
315        }
316    }
317
318    /// Set random state for reproducibility
319    pub fn random_state(mut self, seed: u64) -> Self {
320        self.random_state = Some(seed);
321        self
322    }
323}
324
/// `Estimator` implementation for the unfitted sampler.
///
/// The sampler keeps its settings in its own public fields, so the associated
/// configuration type is the unit type.
impl<K> Estimator for CustomKernelSampler<K, Untrained>
where
    K: KernelFunction,
{
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
337
338impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
339where
340    K: KernelFunction,
341{
342    type Fitted = CustomKernelSampler<K, Trained>;
343
344    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
345        let (_, n_features) = x.dim();
346
347        if self.n_components == 0 {
348            return Err(SklearsError::InvalidInput(
349                "n_components must be positive".to_string(),
350            ));
351        }
352
353        let mut rng = if let Some(seed) = self.random_state {
354            RealStdRng::seed_from_u64(seed)
355        } else {
356            RealStdRng::from_seed(thread_rng().gen())
357        };
358
359        // Sample random frequencies using the kernel's sampling method
360        let random_weights =
361            self.kernel
362                .sample_frequencies(n_features, self.n_components, &mut rng);
363
364        // Sample random offsets from Uniform(0, 2π)
365        let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
366        let mut random_offset = Array1::zeros(self.n_components);
367        for val in random_offset.iter_mut() {
368            *val = rng.sample(uniform);
369        }
370
371        Ok(CustomKernelSampler {
372            kernel: self.kernel,
373            n_components: self.n_components,
374            random_state: self.random_state,
375            random_weights_: Some(random_weights),
376            random_offset_: Some(random_offset),
377            _state: PhantomData,
378        })
379    }
380}
381
382impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
383where
384    K: KernelFunction,
385{
386    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
387        let (n_samples, n_features) = x.dim();
388        let weights = self.random_weights_.as_ref().unwrap();
389        let offset = self.random_offset_.as_ref().unwrap();
390
391        if n_features != weights.nrows() {
392            return Err(SklearsError::InvalidInput(format!(
393                "X has {} features, but CustomKernelSampler was fitted with {} features",
394                n_features,
395                weights.nrows()
396            )));
397        }
398
399        // Compute projection: X @ weights + offset
400        let projection = x.dot(weights) + &offset.view().insert_axis(Axis(0));
401
402        // Apply cosine and normalize: sqrt(2/n_components) * cos(projection)
403        let normalization = (2.0 / self.n_components as Float).sqrt();
404        let result = projection.mapv(|v| normalization * v.cos());
405
406        Ok(result)
407    }
408}
409
410impl<K> CustomKernelSampler<K, Trained>
411where
412    K: KernelFunction,
413{
414    /// Get the random weights
415    pub fn random_weights(&self) -> &Array2<Float> {
416        self.random_weights_.as_ref().unwrap()
417    }
418
419    /// Get the random offset
420    pub fn random_offset(&self) -> &Array1<Float> {
421        self.random_offset_.as_ref().unwrap()
422    }
423
424    /// Get the kernel description
425    pub fn kernel_description(&self) -> String {
426        self.kernel.description()
427    }
428
429    /// Compute exact kernel matrix for comparison/evaluation
430    pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
431        let (n_x, _) = x.dim();
432        let (n_y, _) = y.dim();
433        let mut kernel_matrix = Array2::zeros((n_x, n_y));
434
435        for i in 0..n_x {
436            for j in 0..n_y {
437                let x_row = x.row(i).to_vec();
438                let y_row = y.row(j).to_vec();
439                kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
440            }
441        }
442
443        kernel_matrix
444    }
445}
446
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// RBF kernel is 1 for identical points and exp(-gamma * dist²) otherwise.
    #[test]
    fn test_custom_rbf_kernel() {
        let kernel = CustomRBFKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 3.0];
        let expected = (-0.5_f64 * 2.0).exp(); // dist_sq = (1-2)² + (2-3)² = 2
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    /// Polynomial kernel matches (gamma * <x, y> + coef0)^degree.
    #[test]
    fn test_custom_polynomial_kernel() {
        let kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 3.0];

        let dot_product = 1.0 * 2.0 + 2.0 * 3.0; // = 8
        let expected = (1.0_f64 * dot_product + 1.0).powf(2.0); // = 9² = 81
        assert_abs_diff_eq!(kernel.kernel(&x, &y), expected, epsilon = 1e-10);
    }

    /// Laplacian kernel is 1 for identical points and exp(-gamma * L1) otherwise.
    #[test]
    fn test_custom_laplacian_kernel() {
        let kernel = CustomLaplacianKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 4.0];
        let l1_dist = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs(); // = 1 + 2 = 3
        let expected = (-0.5_f64 * l1_dist).exp(); // = exp(-1.5)
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    /// Fit + transform yields the expected shape and correctly bounded values.
    #[test]
    fn test_custom_kernel_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let sampler = CustomKernelSampler::new(kernel, 50);
        let fitted = sampler.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[3, 50]);

        // Each feature is sqrt(2 / n_components) * cos(..), so its magnitude
        // cannot exceed sqrt(2 / 50).
        let bound = (2.0_f64 / 50.0).sqrt();
        for val in x_transformed.iter() {
            assert!(val.abs() <= bound + 1e-12);
        }
    }

    /// Two samplers with the same seed produce the same features.
    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let kernel1 = CustomRBFKernel::new(0.1);
        let kernel2 = CustomRBFKernel::new(0.1);

        let sampler1 = CustomKernelSampler::new(kernel1, 10).random_state(42);
        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let sampler2 = CustomKernelSampler::new(kernel2, 10).random_state(42);
        let fitted2 = sampler2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Results should be identical with same random state
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// The sampler works with each of the provided kernel types.
    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        // Test with different kernel types
        let rbf_kernel = CustomRBFKernel::new(0.1);
        let rbf_sampler = CustomKernelSampler::new(rbf_kernel, 10);
        let fitted_rbf = rbf_sampler.fit(&x, &()).unwrap();
        let result_rbf = fitted_rbf.transform(&x).unwrap();
        assert_eq!(result_rbf.shape(), &[2, 10]);

        let poly_kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let poly_sampler = CustomKernelSampler::new(poly_kernel, 10);
        let fitted_poly = poly_sampler.fit(&x, &()).unwrap();
        let result_poly = fitted_poly.transform(&x).unwrap();
        assert_eq!(result_poly.shape(), &[2, 10]);

        let lap_kernel = CustomLaplacianKernel::new(0.5);
        let lap_sampler = CustomKernelSampler::new(lap_kernel, 10);
        let fitted_lap = lap_sampler.fit(&x, &()).unwrap();
        let result_lap = fitted_lap.transform(&x).unwrap();
        assert_eq!(result_lap.shape(), &[2, 10]);
    }

    /// The exact kernel matrix agrees with direct kernel evaluation.
    #[test]
    fn test_exact_kernel_matrix_computation() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.5);

        let sampler = CustomKernelSampler::new(kernel.clone(), 10);
        let fitted = sampler.fit(&x, &()).unwrap();
        let kernel_matrix = fitted.exact_kernel_matrix(&x, &y);

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // x[0] and y[0] are the same point, so the RBF kernel gives 1.0 there
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 1.0, epsilon = 1e-10);

        // Manual verification for one element
        let x1 = vec![1.0, 2.0];
        let y2 = vec![5.0, 6.0];
        let expected = kernel.kernel(&x1, &y2);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], expected, epsilon = 1e-10);
    }

    /// Transforming data with a different feature count is rejected.
    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 10);
        let fitted = sampler.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    /// Fitting with zero components is rejected.
    #[test]
    fn test_custom_kernel_zero_components() {
        let x = array![[1.0, 2.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 0);
        let result = sampler.fit(&x, &());
        assert!(result.is_err());
    }
}