Skip to main content

sklears_kernel_approximation/
custom_kernel.rs

1//! Custom kernel random feature generation framework
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Distribution;
6use scirs2_core::random::RngExt;
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Estimator, Fit, Trained, Transform, Untrained},
10    types::Float,
11};
12use std::marker::PhantomData;
13
14use scirs2_core::random::{thread_rng, SeedableRng};
15/// Trait for defining custom kernel functions
16pub trait KernelFunction: Clone + Send + Sync {
17    /// Compute the kernel value between two vectors
18    fn kernel(&self, x: &[Float], y: &[Float]) -> Float;
19
20    /// Get the characteristic function of the kernel's Fourier transform
21    /// This should return the Fourier transform of the kernel at frequency w
22    /// For most kernels, this is needed to generate appropriate random features
23    fn fourier_transform(&self, w: &[Float]) -> Float;
24
25    /// Sample random frequencies for Random Fourier Features
26    /// This method should sample frequencies according to the spectral measure
27    /// of the kernel (i.e., the Fourier transform of the kernel)
28    fn sample_frequencies(
29        &self,
30        n_features: usize,
31        n_components: usize,
32        rng: &mut RealStdRng,
33    ) -> Array2<Float>;
34
35    /// Get a description of the kernel
36    fn description(&self) -> String;
37}
38
39/// Custom RBF kernel with configurable parameters
40#[derive(Debug, Clone)]
41/// CustomRBFKernel
42pub struct CustomRBFKernel {
43    /// gamma
44    pub gamma: Float,
45    /// sigma
46    pub sigma: Float,
47}
48
49impl CustomRBFKernel {
50    pub fn new(gamma: Float) -> Self {
51        Self {
52            gamma,
53            sigma: (1.0 / (2.0 * gamma)).sqrt(),
54        }
55    }
56
57    pub fn from_sigma(sigma: Float) -> Self {
58        let gamma = 1.0 / (2.0 * sigma * sigma);
59        Self { gamma, sigma }
60    }
61}
62
63impl KernelFunction for CustomRBFKernel {
64    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
65        let dist_sq: Float = x
66            .iter()
67            .zip(y.iter())
68            .map(|(xi, yi)| (xi - yi).powi(2))
69            .sum();
70        (-self.gamma * dist_sq).exp()
71    }
72
73    fn fourier_transform(&self, w: &[Float]) -> Float {
74        let w_norm_sq: Float = w.iter().map(|wi| wi.powi(2)).sum();
75        (-w_norm_sq / (4.0 * self.gamma)).exp()
76    }
77
78    fn sample_frequencies(
79        &self,
80        n_features: usize,
81        n_components: usize,
82        rng: &mut RealStdRng,
83    ) -> Array2<Float> {
84        let normal =
85            RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).expect("operation should succeed");
86        let mut weights = Array2::zeros((n_features, n_components));
87        for mut col in weights.columns_mut() {
88            for val in col.iter_mut() {
89                *val = normal.sample(rng);
90            }
91        }
92        weights
93    }
94
95    fn description(&self) -> String {
96        format!("Custom RBF kernel with gamma={}", self.gamma)
97    }
98}
99
100/// Custom polynomial kernel
101#[derive(Debug, Clone)]
102/// CustomPolynomialKernel
103pub struct CustomPolynomialKernel {
104    /// gamma
105    pub gamma: Float,
106    /// coef0
107    pub coef0: Float,
108    /// degree
109    pub degree: u32,
110}
111
112impl CustomPolynomialKernel {
113    pub fn new(degree: u32, gamma: Float, coef0: Float) -> Self {
114        Self {
115            gamma,
116            coef0,
117            degree,
118        }
119    }
120}
121
122impl KernelFunction for CustomPolynomialKernel {
123    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
124        let dot_product: Float = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
125        (self.gamma * dot_product + self.coef0).powf(self.degree as Float)
126    }
127
128    fn fourier_transform(&self, w: &[Float]) -> Float {
129        // Polynomial kernels don't have a simple Fourier transform
130        // We use an approximation based on the dominant frequency component
131        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
132        (1.0 + w_norm * self.gamma).powf(-(self.degree as Float))
133    }
134
135    fn sample_frequencies(
136        &self,
137        n_features: usize,
138        n_components: usize,
139        rng: &mut RealStdRng,
140    ) -> Array2<Float> {
141        // For polynomial kernels, we sample from a scaled normal distribution
142        let normal = RandNormal::new(0.0, self.gamma.sqrt()).expect("operation should succeed");
143        let mut weights = Array2::zeros((n_features, n_components));
144        for mut col in weights.columns_mut() {
145            for val in col.iter_mut() {
146                *val = normal.sample(rng);
147            }
148        }
149        weights
150    }
151
152    fn description(&self) -> String {
153        format!(
154            "Custom Polynomial kernel with degree={}, gamma={}, coef0={}",
155            self.degree, self.gamma, self.coef0
156        )
157    }
158}
159
160/// Custom Laplacian kernel
161#[derive(Debug, Clone)]
162/// CustomLaplacianKernel
163pub struct CustomLaplacianKernel {
164    /// gamma
165    pub gamma: Float,
166}
167
168impl CustomLaplacianKernel {
169    pub fn new(gamma: Float) -> Self {
170        Self { gamma }
171    }
172}
173
174impl KernelFunction for CustomLaplacianKernel {
175    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
176        let l1_dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
177        (-self.gamma * l1_dist).exp()
178    }
179
180    fn fourier_transform(&self, w: &[Float]) -> Float {
181        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
182        self.gamma / (self.gamma + w_norm).powi(2)
183    }
184
185    fn sample_frequencies(
186        &self,
187        n_features: usize,
188        n_components: usize,
189        rng: &mut RealStdRng,
190    ) -> Array2<Float> {
191        use scirs2_core::random::Cauchy;
192        let cauchy = Cauchy::new(0.0, self.gamma).expect("operation should succeed");
193        let mut weights = Array2::zeros((n_features, n_components));
194        for mut col in weights.columns_mut() {
195            for val in col.iter_mut() {
196                *val = cauchy.sample(rng);
197            }
198        }
199        weights
200    }
201
202    fn description(&self) -> String {
203        format!("Custom Laplacian kernel with gamma={}", self.gamma)
204    }
205}
206
207/// Custom exponential kernel (1D version of Laplacian)
208#[derive(Debug, Clone)]
209/// CustomExponentialKernel
210pub struct CustomExponentialKernel {
211    /// length_scale
212    pub length_scale: Float,
213}
214
215impl CustomExponentialKernel {
216    pub fn new(length_scale: Float) -> Self {
217        Self { length_scale }
218    }
219}
220
221impl KernelFunction for CustomExponentialKernel {
222    fn kernel(&self, x: &[Float], y: &[Float]) -> Float {
223        let dist: Float = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
224        (-dist / self.length_scale).exp()
225    }
226
227    fn fourier_transform(&self, w: &[Float]) -> Float {
228        let w_norm: Float = w.iter().map(|wi| wi.abs()).sum();
229        2.0 * self.length_scale / (1.0 + (self.length_scale * w_norm).powi(2))
230    }
231
232    fn sample_frequencies(
233        &self,
234        n_features: usize,
235        n_components: usize,
236        rng: &mut RealStdRng,
237    ) -> Array2<Float> {
238        use scirs2_core::random::Cauchy;
239        let cauchy = Cauchy::new(0.0, 1.0 / self.length_scale).expect("operation should succeed");
240        let mut weights = Array2::zeros((n_features, n_components));
241        for mut col in weights.columns_mut() {
242            for val in col.iter_mut() {
243                *val = cauchy.sample(rng);
244            }
245        }
246        weights
247    }
248
249    fn description(&self) -> String {
250        format!(
251            "Custom Exponential kernel with length_scale={}",
252            self.length_scale
253        )
254    }
255}
256
257/// Custom kernel random feature generator
258///
259/// Generates random Fourier features for any custom kernel function that implements
260/// the KernelFunction trait. This provides a flexible framework for kernel approximation
261/// with user-defined kernels.
262///
263/// # Parameters
264///
265/// * `kernel` - Custom kernel function implementing KernelFunction trait
266/// * `n_components` - Number of random features to generate (default: 100)
267/// * `random_state` - Random seed for reproducibility
268///
269/// # Examples
270///
271/// ```rust,ignore
272/// use sklears_kernel_approximation::custom_kernel::{CustomKernelSampler, CustomRBFKernel};
273/// use sklears_core::traits::{Transform, Fit, Untrained}
274/// use scirs2_core::ndarray::array;
275///
276/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
277/// let kernel = CustomRBFKernel::new(0.1);
278///
279/// let sampler = CustomKernelSampler::new(kernel, 100);
280/// let fitted_sampler = sampler.fit(&X, &()).unwrap();
281/// let X_transformed = fitted_sampler.transform(&X).unwrap();
282/// assert_eq!(X_transformed.shape(), &[3, 100]);
283/// ```
284#[derive(Debug, Clone)]
285/// CustomKernelSampler
286pub struct CustomKernelSampler<K, State = Untrained>
287where
288    K: KernelFunction,
289{
290    /// Custom kernel function
291    pub kernel: K,
292    /// Number of random features
293    pub n_components: usize,
294    /// Random seed
295    pub random_state: Option<u64>,
296
297    // Fitted attributes
298    random_weights_: Option<Array2<Float>>,
299    random_offset_: Option<Array1<Float>>,
300
301    _state: PhantomData<State>,
302}
303
304impl<K> CustomKernelSampler<K, Untrained>
305where
306    K: KernelFunction,
307{
308    /// Create a new custom kernel sampler
309    pub fn new(kernel: K, n_components: usize) -> Self {
310        Self {
311            kernel,
312            n_components,
313            random_state: None,
314            random_weights_: None,
315            random_offset_: None,
316            _state: PhantomData,
317        }
318    }
319
320    /// Set random state for reproducibility
321    pub fn random_state(mut self, seed: u64) -> Self {
322        self.random_state = Some(seed);
323        self
324    }
325}
326
327impl<K> Estimator for CustomKernelSampler<K, Untrained>
328where
329    K: KernelFunction,
330{
331    type Config = ();
332    type Error = SklearsError;
333    type Float = Float;
334
335    fn config(&self) -> &Self::Config {
336        &()
337    }
338}
339
340impl<K> Fit<Array2<Float>, ()> for CustomKernelSampler<K, Untrained>
341where
342    K: KernelFunction,
343{
344    type Fitted = CustomKernelSampler<K, Trained>;
345
346    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
347        let (_, n_features) = x.dim();
348
349        if self.n_components == 0 {
350            return Err(SklearsError::InvalidInput(
351                "n_components must be positive".to_string(),
352            ));
353        }
354
355        let mut rng = if let Some(seed) = self.random_state {
356            RealStdRng::seed_from_u64(seed)
357        } else {
358            RealStdRng::from_seed(thread_rng().random())
359        };
360
361        // Sample random frequencies using the kernel's sampling method
362        let random_weights =
363            self.kernel
364                .sample_frequencies(n_features, self.n_components, &mut rng);
365
366        // Sample random offsets from Uniform(0, 2π)
367        let uniform =
368            RandUniform::new(0.0, 2.0 * std::f64::consts::PI).expect("operation should succeed");
369        let mut random_offset = Array1::zeros(self.n_components);
370        for val in random_offset.iter_mut() {
371            *val = rng.sample(uniform);
372        }
373
374        Ok(CustomKernelSampler {
375            kernel: self.kernel,
376            n_components: self.n_components,
377            random_state: self.random_state,
378            random_weights_: Some(random_weights),
379            random_offset_: Some(random_offset),
380            _state: PhantomData,
381        })
382    }
383}
384
385impl<K> Transform<Array2<Float>, Array2<Float>> for CustomKernelSampler<K, Trained>
386where
387    K: KernelFunction,
388{
389    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
390        let (_n_samples, n_features) = x.dim();
391        let weights = self
392            .random_weights_
393            .as_ref()
394            .expect("operation should succeed");
395        let offset = self
396            .random_offset_
397            .as_ref()
398            .expect("operation should succeed");
399
400        if n_features != weights.nrows() {
401            return Err(SklearsError::InvalidInput(format!(
402                "X has {} features, but CustomKernelSampler was fitted with {} features",
403                n_features,
404                weights.nrows()
405            )));
406        }
407
408        // Compute projection: X @ weights + offset
409        let projection = x.dot(weights) + offset.view().insert_axis(Axis(0));
410
411        // Apply cosine and normalize: sqrt(2/n_components) * cos(projection)
412        let normalization = (2.0 / self.n_components as Float).sqrt();
413        let result = projection.mapv(|v| normalization * v.cos());
414
415        Ok(result)
416    }
417}
418
419impl<K> CustomKernelSampler<K, Trained>
420where
421    K: KernelFunction,
422{
423    /// Get the random weights
424    pub fn random_weights(&self) -> &Array2<Float> {
425        self.random_weights_
426            .as_ref()
427            .expect("operation should succeed")
428    }
429
430    /// Get the random offset
431    pub fn random_offset(&self) -> &Array1<Float> {
432        self.random_offset_
433            .as_ref()
434            .expect("operation should succeed")
435    }
436
437    /// Get the kernel description
438    pub fn kernel_description(&self) -> String {
439        self.kernel.description()
440    }
441
442    /// Compute exact kernel matrix for comparison/evaluation
443    pub fn exact_kernel_matrix(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
444        let (n_x, _) = x.dim();
445        let (n_y, _) = y.dim();
446        let mut kernel_matrix = Array2::zeros((n_x, n_y));
447
448        for i in 0..n_x {
449            for j in 0..n_y {
450                let x_row = x.row(i).to_vec();
451                let y_row = y.row(j).to_vec();
452                kernel_matrix[[i, j]] = self.kernel.kernel(&x_row, &y_row);
453            }
454        }
455
456        kernel_matrix
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_custom_rbf_kernel() {
        let kernel = CustomRBFKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 3.0];
        let expected = (-0.5_f64 * 2.0).exp(); // dist_sq = (1-2)² + (2-3)² = 2
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_rbf_gamma_sigma_parameterizations_agree() {
        // gamma = 1 / (2 sigma^2): sigma = 2 <=> gamma = 0.125
        let from_gamma = CustomRBFKernel::new(0.125);
        assert_abs_diff_eq!(from_gamma.sigma, 2.0, epsilon = 1e-12);

        let from_sigma = CustomRBFKernel::from_sigma(2.0);
        assert_abs_diff_eq!(from_sigma.gamma, 0.125, epsilon = 1e-12);
    }

    #[test]
    fn test_custom_polynomial_kernel() {
        let kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 3.0];

        let dot_product = 1.0 * 2.0 + 2.0 * 3.0; // = 8
        let expected = (1.0_f64 * dot_product + 1.0).powf(2.0); // = 9² = 81
        assert_abs_diff_eq!(kernel.kernel(&x, &y), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_laplacian_kernel() {
        let kernel = CustomLaplacianKernel::new(0.5);
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &y), 1.0, epsilon = 1e-10);

        let y2 = vec![2.0, 4.0];
        let l1_dist = (1.0_f64 - 2.0).abs() + (2.0_f64 - 4.0).abs(); // = 1 + 2 = 3
        let expected = (-0.5_f64 * l1_dist).exp(); // = exp(-1.5)
        assert_abs_diff_eq!(kernel.kernel(&x, &y2), expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_exponential_kernel() {
        let kernel = CustomExponentialKernel::new(2.0);
        let x = vec![1.0, 2.0];

        assert_abs_diff_eq!(kernel.kernel(&x, &x), 1.0, epsilon = 1e-10);

        let y = vec![2.0, 4.0];
        // L1 distance = 3, so k = exp(-3 / 2)
        assert_abs_diff_eq!(kernel.kernel(&x, &y), (-1.5_f64).exp(), epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_matches_laplacian_with_inverse_length_scale() {
        let laplacian = CustomLaplacianKernel::new(0.5);
        let exponential = CustomExponentialKernel::new(2.0);
        let x = vec![0.3, -1.2, 4.0];
        let y = vec![1.1, 0.4, -2.5];

        assert_abs_diff_eq!(
            laplacian.kernel(&x, &y),
            exponential.kernel(&x, &y),
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_custom_kernel_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.1);

        let sampler = CustomKernelSampler::new(kernel, 50);
        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.shape(), &[3, 50]);

        // Each feature is sqrt(2 / n_components) * cos(...), so its magnitude is bounded
        // by sqrt(2 / 50).
        let bound = (2.0_f64 / 50.0).sqrt();
        for val in x_transformed.iter() {
            assert!(val.abs() <= bound + 1e-12);
        }
    }

    #[test]
    fn test_custom_kernel_sampler_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let kernel1 = CustomRBFKernel::new(0.1);
        let kernel2 = CustomRBFKernel::new(0.1);

        let sampler1 = CustomKernelSampler::new(kernel1, 10).random_state(42);
        let fitted1 = sampler1.fit(&x, &()).expect("operation should succeed");
        let result1 = fitted1.transform(&x).expect("operation should succeed");

        let sampler2 = CustomKernelSampler::new(kernel2, 10).random_state(42);
        let fitted2 = sampler2.fit(&x, &()).expect("operation should succeed");
        let result2 = fitted2.transform(&x).expect("operation should succeed");

        // Same seed, same draws: results must be bit-for-bit identical.
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_fitted_parameter_shapes() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let fitted = CustomKernelSampler::new(CustomRBFKernel::new(0.5), 7)
            .random_state(0)
            .fit(&x, &())
            .expect("fit should succeed");

        assert_eq!(fitted.random_weights().dim(), (3, 7));
        assert_eq!(fitted.random_offset().len(), 7);
        for &b in fitted.random_offset().iter() {
            assert!(b >= 0.0 && b <= 2.0 * std::f64::consts::PI);
        }
    }

    #[test]
    fn test_rbf_features_approximate_exact_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = CustomKernelSampler::new(CustomRBFKernel::new(0.1), 5000)
            .random_state(42)
            .fit(&x, &())
            .expect("fit should succeed");

        let z = fitted.transform(&x).expect("transform should succeed");
        let exact = fitted.exact_kernel_matrix(&x, &x);

        // With 5000 features the Monte Carlo error is on the order of 0.01.
        let approx_cross = z.row(0).dot(&z.row(1));
        assert_abs_diff_eq!(approx_cross, exact[[0, 1]], epsilon = 0.1);
    }

    #[test]
    fn test_custom_kernel_sampler_different_kernels() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        // Test with different kernel types
        let rbf_kernel = CustomRBFKernel::new(0.1);
        let rbf_sampler = CustomKernelSampler::new(rbf_kernel, 10);
        let fitted_rbf = rbf_sampler.fit(&x, &()).expect("operation should succeed");
        let result_rbf = fitted_rbf.transform(&x).expect("operation should succeed");
        assert_eq!(result_rbf.shape(), &[2, 10]);

        let poly_kernel = CustomPolynomialKernel::new(2, 1.0, 1.0);
        let poly_sampler = CustomKernelSampler::new(poly_kernel, 10);
        let fitted_poly = poly_sampler.fit(&x, &()).expect("operation should succeed");
        let result_poly = fitted_poly.transform(&x).expect("operation should succeed");
        assert_eq!(result_poly.shape(), &[2, 10]);

        let lap_kernel = CustomLaplacianKernel::new(0.5);
        let lap_sampler = CustomKernelSampler::new(lap_kernel, 10);
        let fitted_lap = lap_sampler.fit(&x, &()).expect("operation should succeed");
        let result_lap = fitted_lap.transform(&x).expect("operation should succeed");
        assert_eq!(result_lap.shape(), &[2, 10]);

        let exp_kernel = CustomExponentialKernel::new(2.0);
        let exp_sampler = CustomKernelSampler::new(exp_kernel, 10);
        let fitted_exp = exp_sampler.fit(&x, &()).expect("operation should succeed");
        let result_exp = fitted_exp.transform(&x).expect("operation should succeed");
        assert_eq!(result_exp.shape(), &[2, 10]);
    }

    #[test]
    fn test_exact_kernel_matrix_computation() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [5.0, 6.0]];
        let kernel = CustomRBFKernel::new(0.5);

        let sampler = CustomKernelSampler::new(kernel.clone(), 10);
        let fitted = sampler.fit(&x, &()).expect("operation should succeed");
        let kernel_matrix = fitted.exact_kernel_matrix(&x, &y);

        assert_eq!(kernel_matrix.shape(), &[2, 2]);

        // x[0] == y[0], so the RBF kernel value is exactly 1.0
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 1.0, epsilon = 1e-10);

        // Manual verification for one element
        let x1 = vec![1.0, 2.0];
        let y2 = vec![5.0, 6.0];
        let expected = kernel.kernel(&x1, &y2);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], expected, epsilon = 1e-10);
    }

    #[test]
    fn test_custom_kernel_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 10);
        let fitted = sampler
            .fit(&x_train, &())
            .expect("operation should succeed");
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_custom_kernel_zero_components() {
        let x = array![[1.0, 2.0]];
        let kernel = CustomRBFKernel::new(0.1);
        let sampler = CustomKernelSampler::new(kernel, 0);
        let result = sampler.fit(&x, &());
        assert!(result.is_err());
    }
}