sklears_kernel_approximation/
deep_learning_kernels.rs

//! Deep Learning Integration for Kernel Approximation
//!
//! This module implements advanced kernel methods inspired by deep learning,
//! including Neural Tangent Kernels (NTK), Deep Kernel Learning, and
//! infinite-width network approximations.
//!
//! # References
//! - Jacot et al. (2018): "Neural Tangent Kernel: Convergence and Generalization in Neural Networks"
//! - Wilson et al. (2016): "Deep Kernel Learning"
//! - Lee et al. (2018): "Deep Neural Networks as Gaussian Processes"
//! - Arora et al. (2019): "On Exact Computation with an Infinitely Wide Neural Net"

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::essentials::{Normal, Uniform};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

const PI: Float = std::f64::consts::PI;

/// Activation functions for neural network kernels
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Activation {
    /// ReLU activation: max(0, x)
    ReLU,
    /// Tanh activation
    Tanh,
    /// Sigmoid activation: 1 / (1 + exp(-x))
    Sigmoid,
    /// Error function (erf)
    Erf,
    /// Linear activation (identity)
    Linear,
    /// GELU activation
    GELU,
    /// Swish activation: x * sigmoid(x)
    Swish,
}

impl Activation {
    /// Apply activation function element-wise
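    ///
    /// A minimal usage sketch (values chosen so the results are exact):
    ///
    /// ```rust,ignore
    /// use sklears_kernel_approximation::deep_learning_kernels::Activation;
    ///
    /// assert_eq!(Activation::ReLU.apply(-2.0), 0.0);
    /// assert_eq!(Activation::Sigmoid.apply(0.0), 0.5);
    /// ```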
    pub fn apply(&self, x: Float) -> Float {
        match self {
            Activation::ReLU => x.max(0.0),
            Activation::Tanh => x.tanh(),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Erf => {
                // Abramowitz and Stegun approximation for erf(x)
                let sign = if x >= 0.0 { 1.0 } else { -1.0 };
                let x_abs = x.abs();
                let t = 1.0 / (1.0 + 0.3275911 * x_abs);
                let approx = 1.0
                    - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736)
                        * t
                        + 0.254829592)
                        * t
                        * (-x_abs * x_abs).exp();
                sign * approx
            }
            Activation::Linear => x,
            Activation::GELU => {
                // GELU(x) = x * Φ(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
                let sqrt_2_over_pi = (2.0 / PI).sqrt();
                0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x.powi(3))).tanh())
            }
            Activation::Swish => {
                let sigmoid = 1.0 / (1.0 + (-x).exp());
                x * sigmoid
            }
        }
    }

    /// Compute the kernel value for this activation at correlation `rho`
    /// This is used in the layer-wise NTK and NNGP kernel computations
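    ///
    /// A minimal sketch of the ReLU case (arc-cosine map): a perfectly correlated
    /// pair gives exactly 0.5 before the caller applies the weight/bias variances.
    ///
    /// ```rust,ignore
    /// use sklears_kernel_approximation::deep_learning_kernels::Activation;
    ///
    /// let k = Activation::ReLU.kernel_value(1.0);
    /// assert!((k - 0.5).abs() < 1e-12);
    /// ```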
    pub fn kernel_value(&self, rho: Float) -> Float {
        match self {
            Activation::ReLU => {
                // Arc-cosine kernel of degree 1:
                // K(x, x') = ||x|| ||x'|| / (2π) * (sin(θ) + (π - θ) cos(θ))
                // where cos(θ) = rho; the caller supplies the ||x|| ||x'|| scaling
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI)
            }
            Activation::Tanh => {
                // Approximation for tanh from the literature; the asin argument is
                // clamped because rho * sqrt(1 + rho^2) exceeds 1 near rho = ±1
                2.0 / PI * (rho * (1.0 + rho.powi(2)).sqrt()).clamp(-1.0, 1.0).asin()
            }
            Activation::Erf => {
                // For erf activation
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).max(0.0).sqrt())).asin()
            }
            Activation::Linear => rho,
            Activation::Sigmoid => {
                // Approximation for sigmoid kernel
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).max(0.0).sqrt())).asin()
            }
            Activation::GELU => {
                // Heuristic GELU kernel approximation (scaled arc-cosine kernel)
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.702
            }
            Activation::Swish => {
                // Heuristic Swish kernel approximation (similar to GELU)
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.5
            }
        }
    }
}

/// Configuration for Neural Tangent Kernel
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NTKConfig {
    /// Number of layers in the neural network
    pub n_layers: usize,
    /// Width of hidden layers (for finite-width approximation)
    pub hidden_width: Option<usize>,
    /// Activation function
    pub activation: Activation,
    /// Whether to use the infinite-width limit
    pub infinite_width: bool,
    /// Variance of weight initialization
    pub weight_variance: Float,
    /// Variance of bias initialization
    pub bias_variance: Float,
}

impl Default for NTKConfig {
    fn default() -> Self {
        Self {
            n_layers: 3,
            hidden_width: Some(1024),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 1.0,
        }
    }
}

/// Neural Tangent Kernel (NTK) Approximation
///
/// The Neural Tangent Kernel describes the evolution of an infinitely-wide neural network
/// during gradient descent. This implementation provides both exact infinite-width
/// computation and finite-width approximations.
///
/// # Mathematical Background
///
/// For a fully-connected neural network with L layers and activation σ:
/// - The NTK is defined as: Θ(x, x') = E[∂f/∂θ(x) · ∂f/∂θ(x')]
/// - In the infinite-width limit, the NTK remains constant during training
/// - The network outputs evolve under gradient flow as: df/dt = -Θ(X, X) ∇_f L(f)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{NeuralTangentKernel, NTKConfig, Activation};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = NTKConfig {
///     n_layers: 3,
///     activation: Activation::ReLU,
///     infinite_width: true,
///     ..Default::default()
/// };
///
/// let ntk = NeuralTangentKernel::new(config);
/// let X = array![[1.0, 2.0], [3.0, 4.0]];
/// let fitted = ntk.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape()[0], 2);
/// ```
#[derive(Debug, Clone)]
pub struct NeuralTangentKernel<State = Untrained> {
    config: NTKConfig,
    n_components: usize,

    // Fitted attributes
    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl NeuralTangentKernel<Untrained> {
    /// Create a new Neural Tangent Kernel with the given configuration
    pub fn new(config: NTKConfig) -> Self {
        Self {
            config,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Create a new NTK with the given number of layers and default configuration otherwise
    pub fn with_layers(n_layers: usize) -> Self {
        Self {
            config: NTKConfig {
                n_layers,
                ..Default::default()
            },
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Set activation function
    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    /// Set whether to use infinite-width limit
    pub fn infinite_width(mut self, infinite: bool) -> Self {
        self.config.infinite_width = infinite;
        self
    }

    /// Set number of components
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Compute the NTK kernel matrix between two sets of points
    fn compute_ntk_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_samples_x = x.nrows();
        let n_samples_y = y.nrows();

        // Initialize kernel matrix with the dot-product kernel (layer 0),
        // normalized by the input dimension for the NTK parameterization
        let d = x.ncols() as Float;
        let mut kernel = x.dot(&y.t());
        kernel.mapv_inplace(|k| k / d);

        // Track the self-kernels k(x_i, x_i) and k(y_j, y_j) separately so the
        // correlation is well-defined even when x and y are different point sets
        let mut diag_x = Array1::from_shape_fn(n_samples_x, |i| x.row(i).dot(&x.row(i)) / d);
        let mut diag_y = Array1::from_shape_fn(n_samples_y, |j| y.row(j).dot(&y.row(j)) / d);

        // Recursively compute kernels through layers
        for _layer in 0..self.config.n_layers {
            let mut new_kernel = Array2::zeros((n_samples_x, n_samples_y));

            for i in 0..n_samples_x {
                for j in 0..n_samples_y {
                    let k_ij = kernel[[i, j]];

                    // Compute correlation
                    let norm = (diag_x[i] * diag_y[j]).sqrt().max(1e-10);
                    let rho = (k_ij / norm).clamp(-1.0, 1.0);

                    // Apply activation kernel and scale by variances
                    let activated = self.config.activation.kernel_value(rho);
                    new_kernel[[i, j]] =
                        self.config.weight_variance * norm * activated + self.config.bias_variance;
                }
            }

            // Propagate the self-kernels through the same recursion (rho = 1)
            let self_activation = self.config.activation.kernel_value(1.0);
            diag_x.mapv_inplace(|k| {
                self.config.weight_variance * k * self_activation + self.config.bias_variance
            });
            diag_y.mapv_inplace(|k| {
                self.config.weight_variance * k * self_activation + self.config.bias_variance
            });

            kernel = new_kernel;
        }

        Ok(kernel)
    }

    /// Compute top k eigenvectors using power iteration
    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).unwrap();

        for i in 0..k {
            // Random initialization
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal));

            // Power iteration
            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            // Store eigenvector
            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            // Deflate kernel
            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}


impl Estimator for NeuralTangentKernel<Untrained> {
    type Config = NTKConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for NeuralTangentKernel<Untrained> {
    type Fitted = NeuralTangentKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        // Store training data
        let x_train = x.clone();

        // For dimensionality reduction, compute eigendecomposition of kernel matrix
        let kernel = self.compute_ntk_kernel(x, x)?;

        // Use eigendecomposition for feature extraction
        let n_components = x.nrows().min(self.n_components);

        // Simple eigendecomposition using power iteration for top eigenvectors
        let eigenvectors = if n_components < x.nrows() {
            Some(self.compute_top_eigenvectors(&kernel, n_components)?)
        } else {
            None
        };

        Ok(NeuralTangentKernel {
            config: self.config,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors,
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for NeuralTangentKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self.x_train.as_ref().unwrap();

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        // Compute kernel between x and training data
        let ntk = NeuralTangentKernel::<Untrained> {
            config: self.config.clone(),
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };
        let kernel = ntk.compute_ntk_kernel(x, x_train)?;

        // Project onto eigenvectors if available
        if let Some(ref eigvecs) = self.eigenvectors {
            Ok(kernel.dot(eigvecs))
        } else {
            Ok(kernel)
        }
    }
}

impl NeuralTangentKernel<Trained> {
    /// Get the training data
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().unwrap()
    }

    /// Get the eigenvectors
    pub fn eigenvectors(&self) -> Option<&Array2<Float>> {
        self.eigenvectors.as_ref()
    }
}

/// Configuration for Deep Kernel Learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DKLConfig {
    /// Sizes of feature extraction layers
    pub feature_layers: Vec<usize>,
    /// Number of random Fourier features for final kernel
    pub n_components: usize,
    /// Activation function for feature layers
    pub activation: Activation,
    /// Bandwidth for final RBF kernel
    pub gamma: Float,
    /// Learning rate for feature learning (currently not trainable, uses random features)
    pub learning_rate: Float,
}

impl Default for DKLConfig {
    fn default() -> Self {
        Self {
            feature_layers: vec![64, 32],
            n_components: 100,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        }
    }
}

/// Deep Kernel Learning combines deep neural networks with kernel methods
///
/// This approach uses a deep neural network to learn a feature representation,
/// then applies a kernel method in the learned feature space.
///
/// # Mathematical Background
///
/// Deep Kernel Learning defines a composite kernel:
/// k(x, x') = k_base(φ(x; θ), φ(x'; θ))
/// where:
/// - φ(·; θ) is a deep neural network with parameters θ
/// - k_base is a base kernel (e.g., RBF)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{DeepKernelLearning, DKLConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = DKLConfig {
///     feature_layers: vec![10, 20, 10],
///     n_components: 50,
///     ..Default::default()
/// };
///
/// let dkl = DeepKernelLearning::new(config);
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let fitted = dkl.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape(), &[3, 50]);
/// ```
#[derive(Debug, Clone)]
pub struct DeepKernelLearning<State = Untrained> {
    config: DKLConfig,

    // Fitted attributes
    layer_weights: Option<Vec<Array2<Float>>>,
    layer_biases: Option<Vec<Array1<Float>>>,
    random_weights: Option<Array2<Float>>,
    random_offset: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl DeepKernelLearning<Untrained> {
    /// Create new Deep Kernel Learning with configuration
    pub fn new(config: DKLConfig) -> Self {
        Self {
            config,
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    /// Create with the given number of components and default configuration otherwise
    pub fn with_components(n_components: usize) -> Self {
        Self {
            config: DKLConfig {
                n_components,
                ..Default::default()
            },
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    /// Set activation function
    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    /// Set gamma (bandwidth) for final RBF kernel
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.config.gamma = gamma;
        self
    }
}

impl Estimator for DeepKernelLearning<Untrained> {
    type Config = DKLConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for DeepKernelLearning<Untrained> {
    type Fitted = DeepKernelLearning<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).unwrap();

        // Initialize feature extraction layers with random weights
        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();

        let mut in_features = x.ncols();
        for &out_features in &self.config.feature_layers {
            // Xavier/Glorot initialization
            let scale = (2.0 / (in_features + out_features) as Float).sqrt();

            let weights = Array2::from_shape_fn((in_features, out_features), |_| {
                rng.sample(normal_dist) * scale
            });

            let biases = Array1::from_shape_fn(out_features, |_| rng.sample(normal_dist) * 0.01);

            layer_weights.push(weights);
            layer_biases.push(biases);
            in_features = out_features;
        }

        // Initialize random Fourier features for the final RBF kernel:
        // weights ~ N(0, 2 * gamma) approximate exp(-gamma * ||x - x'||^2)
        let final_features = if self.config.feature_layers.is_empty() {
            x.ncols()
        } else {
            *self.config.feature_layers.last().unwrap()
        };

        let random_weights =
            Array2::from_shape_fn((final_features, self.config.n_components), |_| {
                rng.sample(normal_dist) * (2.0 * self.config.gamma).sqrt()
            });

        let uniform_dist = Uniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset =
            Array1::from_shape_fn(self.config.n_components, |_| rng.sample(uniform_dist));

        Ok(DeepKernelLearning {
            config: self.config,
            layer_weights: Some(layer_weights),
            layer_biases: Some(layer_biases),
            random_weights: Some(random_weights),
            random_offset: Some(random_offset),
            _state: PhantomData,
        })
    }
}

impl DeepKernelLearning<Trained> {
    /// Apply feature extraction layers
    fn extract_features(&self, x: &Array2<Float>) -> Array2<Float> {
        let mut features = x.clone();
        let layer_weights = self.layer_weights.as_ref().unwrap();
        let layer_biases = self.layer_biases.as_ref().unwrap();

        for (weights, biases) in layer_weights.iter().zip(layer_biases.iter()) {
            // Linear transformation
            features = features.dot(weights);

            // Add bias
            for i in 0..features.nrows() {
                for j in 0..features.ncols() {
                    features[[i, j]] += biases[j];
                }
            }

            // Apply activation
            features.mapv_inplace(|v| self.config.activation.apply(v));
        }

        features
    }
}

impl Transform<Array2<Float>, Array2<Float>> for DeepKernelLearning<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        // Extract deep features
        let deep_features = self.extract_features(x);

        // Apply random Fourier features to deep features
        let random_weights = self.random_weights.as_ref().unwrap();
        let random_offset = self.random_offset.as_ref().unwrap();

        let projection = deep_features.dot(random_weights);

        let n_samples = x.nrows();
        let mut output = Array2::zeros((n_samples, self.config.n_components));

        let normalizer = (2.0 / self.config.n_components as Float).sqrt();
        for i in 0..n_samples {
            for j in 0..self.config.n_components {
                output[[i, j]] = normalizer * (projection[[i, j]] + random_offset[j]).cos();
            }
        }

        Ok(output)
    }
}

impl DeepKernelLearning<Trained> {
    /// Get the layer weights
    pub fn layer_weights(&self) -> &Vec<Array2<Float>> {
        self.layer_weights.as_ref().unwrap()
    }

    /// Get the layer biases
    pub fn layer_biases(&self) -> &Vec<Array1<Float>> {
        self.layer_biases.as_ref().unwrap()
    }

    /// Get the random weights
    pub fn random_weights(&self) -> &Array2<Float> {
        self.random_weights.as_ref().unwrap()
    }

    /// Get the random offset
    pub fn random_offset(&self) -> &Array1<Float> {
        self.random_offset.as_ref().unwrap()
    }
}

/// Infinite-Width Network Kernel
///
/// This implements the kernel corresponding to an infinitely-wide neural network,
/// also known as the Neural Network Gaussian Process (NNGP).
///
/// # Mathematical Background
///
/// For a neural network f(x; θ) with random weights θ ~ N(0, σ²/n_h):
/// As the hidden layer width n_h → ∞, f(x; θ) → GP(0, K)
/// where K is the NNGP kernel determined by the architecture and activation.
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{InfiniteWidthKernel, Activation};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
/// let X = array![[1.0, 2.0], [3.0, 4.0]];
/// let fitted = kernel.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape()[0], 2);
/// ```
#[derive(Debug, Clone)]
pub struct InfiniteWidthKernel<State = Untrained> {
    n_layers: usize,
    activation: Activation,
    n_components: usize,

    // Fitted attributes
    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl InfiniteWidthKernel<Untrained> {
    /// Create new infinite-width kernel
    pub fn new(n_layers: usize, activation: Activation) -> Self {
        Self {
            n_layers,
            activation,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Set number of components for dimensionality reduction
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Compute the NNGP kernel matrix
    fn compute_nngp_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
        let n_x = x.nrows();
        let n_y = y.nrows();
        let d = x.ncols() as Float;

        // Initialize with normalized dot product
        let mut kernel = x.dot(&y.t());
        kernel.mapv_inplace(|k| k / d);

        // Track the self-kernels for x and y separately so correlations stay
        // well-defined when x and y are different point sets
        let mut diag_x = Array1::from_shape_fn(n_x, |i| x.row(i).dot(&x.row(i)) / d);
        let mut diag_y = Array1::from_shape_fn(n_y, |j| y.row(j).dot(&y.row(j)) / d);

        // Recursively apply activation kernels
        for _ in 0..self.n_layers {
            let mut new_kernel = Array2::zeros((n_x, n_y));

            for i in 0..n_x {
                for j in 0..n_y {
                    let k_ij = kernel[[i, j]];
                    let norm = (diag_x[i] * diag_y[j]).sqrt().max(1e-10);
                    let rho = (k_ij / norm).clamp(-1.0, 1.0);

                    new_kernel[[i, j]] = norm * self.activation.kernel_value(rho);
                }
            }

            // Propagate the self-kernels through the same recursion (rho = 1)
            let self_activation = self.activation.kernel_value(1.0);
            diag_x.mapv_inplace(|k| k * self_activation);
            diag_y.mapv_inplace(|k| k * self_activation);

            kernel = new_kernel;
        }

        kernel
    }

    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).unwrap();

        for i in 0..k {
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal_dist));

            // Power iteration
            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}

impl Estimator for InfiniteWidthKernel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for InfiniteWidthKernel<Untrained> {
    type Fitted = InfiniteWidthKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let x_train = x.clone();
        let kernel = self.compute_nngp_kernel(x, x);

        // Compute eigenvectors using power iteration
        let n_components = self.n_components.min(x.nrows());
        let eigenvectors = self.compute_top_eigenvectors(&kernel, n_components)?;

        Ok(InfiniteWidthKernel {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors: Some(eigenvectors),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for InfiniteWidthKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self.x_train.as_ref().unwrap();
        let eigenvectors = self.eigenvectors.as_ref().unwrap();

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        let kernel_obj = InfiniteWidthKernel::<Untrained> {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };

        let kernel = kernel_obj.compute_nngp_kernel(x, x_train);
        Ok(kernel.dot(eigenvectors))
    }
}

impl InfiniteWidthKernel<Trained> {
    /// Get the training data
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().unwrap()
    }

    /// Get the eigenvectors
    pub fn eigenvectors(&self) -> &Array2<Float> {
        self.eigenvectors.as_ref().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_activation_functions() {
        let activations = vec![
            Activation::ReLU,
            Activation::Tanh,
            Activation::Sigmoid,
            Activation::Linear,
            Activation::GELU,
            Activation::Swish,
            Activation::Erf,
        ];

        for act in activations {
            let val = act.apply(0.5);
            assert!(val.is_finite());

            let kernel_val = act.kernel_value(0.5);
            assert!(kernel_val.is_finite());
        }
    }
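
    // The kernel recursions evaluate kernel_value at correlations anywhere in
    // [-1, 1], including the endpoints reached on the diagonal. A sanity check
    // (a hedged sketch, not a statement about kernel accuracy) that every
    // activation kernel stays finite over that whole range.
    #[test]
    fn test_kernel_value_finite_over_correlation_range() {
        let activations = vec![
            Activation::ReLU,
            Activation::Tanh,
            Activation::Sigmoid,
            Activation::Linear,
            Activation::GELU,
            Activation::Swish,
            Activation::Erf,
        ];

        for act in activations {
            let mut rho = -1.0;
            while rho <= 1.0 {
                assert!(act.kernel_value(rho).is_finite());
                rho += 0.25;
            }
        }
    }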

    #[test]
    fn test_neural_tangent_kernel_basic() {
        let config = NTKConfig {
            n_layers: 2,
            hidden_width: Some(512),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 0.1,
        };

        let ntk = NeuralTangentKernel::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = ntk.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_deep_kernel_learning() {
        let config = DKLConfig {
            feature_layers: vec![10, 20],
            n_components: 50,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);
    }
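
    // Each random Fourier feature is sqrt(2 / n_components) * cos(...), so its
    // magnitude is bounded by the normalizer. A hedged check of that invariant
    // on a small example.
    #[test]
    fn test_dkl_features_bounded_by_normalizer() {
        let dkl = DeepKernelLearning::with_components(50);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        let bound = (2.0 / 50.0 as Float).sqrt() + 1e-12;
        for val in features.iter() {
            assert!(val.abs() <= bound);
        }
    }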

    #[test]
    fn test_infinite_width_kernel() {
        let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.nrows(), 4);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_ntk_different_activations() {
        let activations = vec![Activation::ReLU, Activation::Tanh, Activation::GELU];
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        for act in activations {
            let ntk = NeuralTangentKernel::with_layers(2).activation(act);
            let fitted = ntk.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.nrows(), 2);
        }
    }
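
    // The fitted transformer should also handle points that were not in the
    // training set; this is a hedged smoke test of that cross-kernel path.
    #[test]
    fn test_ntk_transform_on_new_data() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x_new = array![[0.5, 1.5], [2.0, 1.0]];

        let fitted = ntk.fit(&x_train, &()).unwrap();
        let features = fitted.transform(&x_new).unwrap();

        assert_eq!(features.nrows(), 2);
        for val in features.iter() {
            assert!(val.is_finite());
        }
    }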

    #[test]
    fn test_dkl_feature_extraction() {
        let config = DKLConfig {
            feature_layers: vec![8, 4],
            n_components: 20,
            activation: Activation::Tanh,
            gamma: 0.5,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).unwrap();

        // Test that features have correct shape
        let features = fitted.transform(&x).unwrap();
        assert_eq!(features.shape(), &[2, 20]);

        // Test that all features are finite
        for val in features.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_empty_input_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_empty: Array2<Float> = Array2::zeros((0, 0));

        assert!(ntk.fit(&x_empty, &()).is_err());
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];

        let fitted = ntk.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }
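
    // Companion check for the NNGP path: transforming points outside the
    // training set should give a (n_new, n_components) projection with finite
    // entries. A hedged smoke test, not a statement about kernel accuracy.
    #[test]
    fn test_infinite_width_transform_on_new_data() {
        let kernel = InfiniteWidthKernel::new(2, Activation::ReLU).n_components(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let x_new = array![[2.0, 1.0]];

        let fitted = kernel.fit(&x_train, &()).unwrap();
        let features = fitted.transform(&x_new).unwrap();

        assert_eq!(features.shape(), &[1, 2]);
        for val in features.iter() {
            assert!(val.is_finite());
        }
    }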
}