sklears_kernel_approximation/
custom_kernel.rs

1//! Custom kernel random feature generation framework
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Distribution;
6use scirs2_core::random::Rng;
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Estimator, Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14use scirs2_core::random::{thread_rng, SeedableRng};
15/// Trait for defining custom kernel functions
16pub trait KernelFunction: Clone + Send + Sync {
17    /// Compute the kernel value between two vectors
18    fn kernel(&self, x: &[Float], y: &[Float]) -> Float;
19
20    /// Get the characteristic function of the kernel's Fourier transform
21    /// This should return the Fourier transform of the kernel at frequency w
22    /// For most kernels, this is needed to generate appropriate random features
23    fn fourier_transform(&self, w: &[Float]) -> Float;
24
25    /// Sample random frequencies for Random Fourier Features
26    /// This method should sample frequencies according to the spectral measure
27    /// of the kernel (i.e., the Fourier transform of the kernel)
28    fn sample_frequencies(
29        &self,
30        n_features: usize,
31        n_components: usize,
32        rng: &mut RealStdRng,
33    ) -> Array2<Float>;
34
35    /// Get a description of the kernel
36    fn description(&self) -> String;
37}
38
39/// Custom RBF kernel with configurable parameters
40#[derive(Debug, Clone)]
41/// CustomRBFKernel
42pub struct CustomRBFKernel {
43    /// gamma
44    pub gamma: Float,
45    /// sigma
46    pub sigma: Float,
47}
48
49impl CustomRBFKernel {
50    pub fn new(gamma: Float) -> Self {
51        Self {
52            gamma,
53            sigma: (1.0 / (2.0 * gamma)).sqrt(),
54        }
55    }
56
57    pub fn from_sigma(sigma: Float) -> Self {
58        let gamma = 1.0 / (2.0 * sigma * sigma);
59        Self { gamma, sigma }
60    }
61}
62
63impl KernelFunction for CustomRBFKernel {
64    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
65        let dist_sq: Float = x
66            .iter()
67            .zip(y.iter())
68            .map(|(xi, yi)| (xi - yi).powi(2))
69            .sum();
70        (-self.gamma * dist_sq).exp()
71    }
72
73    fn fourier_transform(&self, w: &[Float]) -> Float {
74        let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
75        (-w_norm_sq / (4.0 * self.gamma)).exp()
76    }
77
78    fn sample_frequencies(
79        &self,
80        n_features: usize,
81        n_components: usize,
82        rng: &mut RealStdRng,
83    ) -> Array2<Float> {
84        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
85        let mut weights = Array2::zeros((n_features, n_components));
86        for mut col in weights.columns_mut() {
87            for val in col.iter_mut() {
88                *val = normal.sample(rng);
89            }
90        }
91        weights
92    }
93
94    fn description(&self) -> String {
95        format!("Custom RBF kernel with gamma={}", self.gamma)
96    }
97}
98
99/// Custom polynomial kernel
100#[derive(Debug, Clone)]
101/// CustomPolynomialKernel
102pub struct CustomPolynomialKernel {
103    /// gamma
104    pub gamma: Float,
105    /// coef0
106    pub coef0: Float,
107    /// degree
108    pub degree: u32,
109}
110
111impl CustomPolynomialKernel {
112    pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
113        Self {
114            gamma,
115            coef0,
116            degree,
117        }
118    }
119}
120
121impl KernelFunction for CustomPolynomialKernel {
122    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
123        let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
124        (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
125    }
126
127    fn fourier_transform(&self, w: &[Float]) -> Float {
128        // Polynomial kernels don't have a simple Fourier transform
129        // We use an approximation based on the dominant frequency component
130        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
131        (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
132    }
133
134    fn sample_frequencies(
135        &self,
136        n_features: usize,
137        n_components: usize,
138        rng: &mut RealStdRng,
139    ) -> Array2<Float> {
140        // For polynomial kernels, we sample from a scaled normal distribution
141        let normal = RandNormal::new(0.0, self.gamma.sqrt()).unwrap();
142        let mut weights = Array2::zeros((n_features, n_components));
143        for mut col in weights.columns_mut() {
144            for val in col.iter_mut() {
145                *val = normal.sample(rng);
146            }
147        }
148        weights
149    }
150
151    fn description(&self) -> String {
152        format!(
153            "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
154            self.degree, self.gamma, self.coef0
155        )
156    }
157}
158
159/// Custom Laplacian kernel
160#[derive(Debug, Clone)]
161/// CustomLaplacianKernel
162pub struct CustomLaplacianKernel {
163    /// gamma
164    pub gamma: Float,
165}
166
167impl CustomLaplacianKernel {
168    pub fn new(gamma: Float) -> Self {
169        Self { gamma }
170    }
171}
172
173impl KernelFunction for CustomLaplacianKernel {
174    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
175        let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
176        (-self.gamma * l1_dist).exp()
177    }
178
179    fn fourier_transform(&self, w: &[Float]) -> Float {
180        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
181        self.gamma / (self.gamma + w_norm).powi(2)
182    }
183
184    fn sample_frequencies(
185        &self,
186        n_features: usize,
187        n_components: usize,
188        rng: &mut RealStdRng,
189    ) -> Array2<Float> {
190        use scirs2_core::random::Cauchy;
191        let cauchy = Cauchy::new(0.0, self.gamma).unwrap();
192        let mut weights = Array2::zeros((n_features, n_components));
193        for mut col in weights.columns_mut() {
194            for val in col.iter_mut() {
195                *val = cauchy.sample(rng);
196            }
197        }
198        weights
199    }
200
201    fn description(&self) -> String {
202        format!("Custom Laplacian kernel with gamma={}", self.gamma)
203    }
204}
205
206/// Custom exponential kernel (1D version of Laplacian)
207#[derive(Debug, Clone)]
208/// CustomExponentialKernel
209pub struct CustomExponentialKernel {
210    /// length_scale
211    pub length_scale: Float,
212}
213
214impl CustomExponentialKernel {
215    pub fn new(length_scale: Float) -> Self {
216        Self { length_scale }
217    }
218}
219
220impl KernelFunction for CustomExponentialKernel {
221    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
222        let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
223        (-dist / self.length_scale).exp()
224    }
225
226    fn fourier_transform(&self, w: &[Float]) -> Float {
227        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
228        2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
229    }
230
231    fn sample_frequencies(
232        &self,
233        n_features: usize,
234        n_components: usize,
235        rng: &mut RealStdRng,
236    ) -> Array2<Float> {
237        use scirs2_core::random::Cauchy;
238        let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).unwrap();
239        let mut weights = Array2::zeros((n_features, n_components));
240        for mut col in weights.columns_mut() {
241            for val in col.iter_mut() {
242                *val = cauchy.sample(rng);
243            }
244        }
245        weights
246    }
247
248    fn description(&self) -> String {
249        format!(
250            "Custom Exponential kernel with length_scale={}",
251            self.length_scale
252        )
253    }
254}
255
256/// Custom kernel random feature generator
257///
258/// Generates random Fourier features for any custom kernel function that implements
259/// the KernelFunction trait. This provides a flexible framework for kernel approximation
260/// with user-defined kernels.
261///
262/// # Parameters
263///
264/// * `kernel` - Custom kernel function implementing KernelFunction trait
265/// * `n_components` - Number of random features to generate (default: 100)
266/// * `random_state` - Random seed for reproducibility
267///
268/// # Examples
269///
270/// ```rust,ignore
271/// use sklears_kernel_approximation::custom_kernel::{CustomKernelSampler, CustomRBFKernel};
272/// use sklears_core::traits::{Transform, Fit, Untrained}
273/// use scirs2_core::ndarray::array;
274///
275/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
276/// let kernel = CustomRBFKernel::new(0.1);
277///
278/// let sampler = CustomKernelSampler::new(kernel, 100);
279/// let fitted_sampler = sampler.fit(&X, &()).unwrap();
280/// let X_transformed = fitted_sampler.transform(&X).unwrap();
281/// assert_eq!(X_transformed.shape(), &[3, 100]);
282/// ```
283#[derive(Debug, Clone)]
284/// CustomKernelSampler
285pub struct CustomKernelSampler<K, State = Untrained>
286where
287    K: KernelFunction,
288{
289    /// Custom kernel function
290    pub kernel: K,
291    /// Number of random features
292    pub n_components: usize,
293    /// Random seed
294    pub random_state: Option<u64>,
295
296    // Fitted attributes
297    random_weights_: Option<Array2<Float>>,
298    random_offset_: Option<Array1<Float>>,
299
300    _state: PhantomData<State>,
301}
302
303impl<K> CustomKernelSampler<K, Untrained>
304where
305    K: KernelFunction,
306{
307    /// Create a new custom kernel sampler
308    pub fn new(kernel: K, n_components: usize) -> Self {
309        Self {
310            kernel,
311            n_components,
312            random_state: None,
313            random_weights_: None,
314            random_offset_: None,
315            _state: PhantomData,
316        }
317    }
318
319    /// Set random state for reproducibility
320    pub fn random_state(mut self, seed: u64) -> Self {
321        self.random_state = Some(seed);
322        self
323    }
324}
325
326impl<K> Estimator for CustomKernelSampler<K, Untrained>
327where
328    K: KernelFunction,
329{
330    type Config = ();
331    type Error = SklearsError;
332    type Float = Float;
333
334    fn config(&self) -> &Self::Config {
335        &()
336    }
337}
338
339impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
340where
341    K: KernelFunction,
342{
343    type Fitted = CustomKernelSampler<K, Trained>;
344
345    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
346        let (_, n_features) = x.dim();
347
348        if self.n_components == 0 {
349            return Err(SklearsError::InvalidInput(
350                "n_components must be positive".to_string(),
351            ));
352        }
353
354        let mut rng = if let Some(seed) = self.random_state {
355            RealStdRng::seed_from_u64(seed)
356        } else {
357            RealStdRng::from_seed(thread_rng().gen())
358        };
359
360        // Sample random frequencies using the kernel's sampling method
361        let random_weights =
362            self.kernel
363                .sample_frequencies(n_features, self.n_components, &mut rng);
364
365        // Sample random offsets from Uniform(0, 2π)
366        let uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
367        let mut random_offset = Array1::zeros(self.n_components);
368        for val in random_offset.iter_mut() {
369            *val = rng.sample(uniform);
370        }
371
372        Ok(CustomKernelSampler {
373            kernel: self.kernel,
374            n_components: self.n_components,
375            random_state: self.random_state,
376            random_weights_: Some(random_weights),
377            random_offset_: Some(random_offset),
378            _state: PhantomData,
379        })
380    }
381}
382
383impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
384where
385    K: KernelFunction,
386{
387    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
388        let (_n_samples, n_features) = x.dim();
389        let weights = self.random_weights_.as_ref().unwrap();
390        let offset = self.random_offset_.as_ref().unwrap();
391
392        if n_features != weights.nrows() {
393            return Err(SklearsError::InvalidInput(format!(
394                "X has {} features, but CustomKernelSampler was fitted with {} features",
395                n_features,
396                weights.nrows()
397            )));
398        }
399
400        // Compute projection: X @ weights + offset
401        let projection = x.dot(weights) + offset.view().insert_axis(Axis(0));
402
403        // Apply cosine and normalize: sqrt(2/n_components) * cos(projection)
404        let normalization = (2.0 / self.n_components as Float).sqrt();
405        let result = projection.mapv(|v| normalization * v.cos());
406
407        Ok(result)
408    }
409}
410
411impl<K> CustomKernelSampler<K, Trained>
412where
413    K: KernelFunction,
414{
415    /// Get the random weights
416    pub fn random_weights(&self) -> &Array2<Float> {
417        self.random_weights_.as_ref().unwrap()
418    }
419
420    /// Get the random offset
421    pub fn random_offset(&self) -> &Array1<Float> {
422        self.random_offset_.as_ref().unwrap()
423    }
424
425    /// Get the kernel description
426    pub fn kernel_description(&self) -> String {
427        self.kernel.description()
428    }
429
430    /// Compute exact kernel matrix for comparison/evaluation
431    pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
432        let (n_x, _) = x.dim();
433        let (n_y, _) = y.dim();
434        let mut kernel_matrix = Array2::zeros((n_x, n_y));
435
436        for i in 0..n_x {
437            for j in 0..n_y {
438                let x_row = x.row(i).to_vec();
439                let y_row = y.row(j).to_vec();
440                kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
441            }
442        }
443
444        kernel_matrix
445    }
446}
447
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_custom_rbf_kernel() {
        let kernel = CustomRBFKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 3.0];
        let expected = (-0.5_f64 * 2.0).exp(); // dist_sq = (1-2)² + (2-3)² = 2
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_polynomial_kernel() {
        let kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 3.0];

        let dot_product = 1.0 * 2.0 + 2.0 * 3.0; // = 8
        let expected = (1.0_f64 * dot_product + 1.0).powf(2.0); // = 9² = 81
        assert_abs_diff_eq!(kernel.kernel(&x, &y), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_laplacian_kernel() {
        let kernel = CustomLaplacianKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 4.0];
        let l1_dist = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs(); // = 1 + 2 = 3
        let expected = (-0.5_f64 * l1_dist).exp(); // = exp(-1.5)
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let sampler = CustomKernelSampler::new(kernel, 50);
        let fitted = sampler.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[3, 50]);

        // Every feature is sqrt(2 / n_components) * cos(..), so its magnitude is at most
        // sqrt(2 / 50); the small slack absorbs rounding.
        let bound = (2.0_f64 / 50.0).sqrt() + 1e-12;
        for val in x_transformed.iter() {
            assert!(val.abs() <= bound);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let kernel1 = CustomRBFKernel::new(0.1);
        let kernel2 = CustomRBFKernel::new(0.1);

        let sampler1 = CustomKernelSampler::new(kernel1, 10).random_state(42);
        let fitted1 = sampler1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let sampler2 = CustomKernelSampler::new(kernel2, 10).random_state(42);
        let fitted2 = sampler2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Results should be identical with same random state
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    #[test]
    fn test_custom_rbf_sampler_approximates_exact_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let fitted = CustomKernelSampler::new(kernel, 5000)
            .random_state(0)
            .fit(&x, &())
            .unwrap();
        let features = fitted.transform(&x).unwrap();
        let approx_gram = features.dot(&features.t());
        let exact_gram = fitted.exact_kernel_matrix(&x, &x);

        // With 5000 components the Monte-Carlo error is around 0.015, so 0.1 is a wide margin
        // that still catches a mis-scaled frequency distribution.
        for (a, e) in approx_gram.iter().zip(exact_gram.iter()) {
            assert!((a - e).abs() < 0.1, "approx {} vs exact {}", a, e);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        // Test with different kernel types
        let rbf_kernel = CustomRBFKernel::new(0.1);
        let rbf_sampler = CustomKernelSampler::new(rbf_kernel, 10);
        let fitted_rbf = rbf_sampler.fit(&x, &()).unwrap();
        let result_rbf = fitted_rbf.transform(&x).unwrap();
        assert_eq!(result_rbf.shape(), &[2, 10]);

        let poly_kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let poly_sampler = CustomKernelSampler::new(poly_kernel, 10);
        let fitted_poly = poly_sampler.fit(&x, &()).unwrap();
        let result_poly = fitted_poly.transform(&x).unwrap();
        assert_eq!(result_poly.shape(), &[2, 10]);

        let lap_kernel = CustomLaplacianKernel::new(0.5);
        let lap_sampler = CustomKernelSampler::new(lap_kernel, 10);
        let fitted_lap = lap_sampler.fit(&x, &()).unwrap();
        let result_lap = fitted_lap.transform(&x).unwrap();
        assert_eq!(result_lap.shape(), &[2, 10]);
    }

    #[test]
    fn test_exact_kernel_matrix_computation() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.5);

        let sampler = CustomKernelSampler::new(kernel.clone(), 10);
        let fitted = sampler.fit(&x, &()).unwrap();
        let kernel_matrix = fitted.exact_kernel_matrix(&x, &y);

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // Check diagonal elements (should be 1.0 for RBF kernel with same points)
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 1.0, epsilon = 1e-10);

        // Manual verification for one element
        let x1 = vec![1.0, 2.0];
        let y2 = vec![5.0, 6.0];
        let expected = kernel.kernel(&x1, &y2);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 10);
        let fitted = sampler.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_custom_kernel_zero_components() {
        let x = array![[1.0, 2.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 0);
        let result = sampler.fit(&x, &());
        assert!(result.is_err());
    }
}