sklears_kernel_approximation/deep_learning_kernels.rs

//! Deep Learning Integration for Kernel Approximation
//!
//! This module implements advanced kernel methods inspired by deep learning,
//! including Neural Tangent Kernels (NTK), Deep Kernel Learning, and
//! infinite-width network approximations.
//!
//! # References
//! - Jacot et al. (2018): "Neural Tangent Kernel: Convergence and Generalization in Neural Networks"
//! - Wilson et al. (2016): "Deep Kernel Learning"
//! - Lee et al. (2018): "Deep Neural Networks as Gaussian Processes"
//! - Arora et al. (2019): "On Exact Computation with an Infinitely Wide Neural Net"

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::essentials::{Normal, Uniform};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

const PI: Float = std::f64::consts::PI;

/// Activation functions for neural network kernels
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Activation {
    /// ReLU activation: max(0, x)
    ReLU,
    /// Tanh activation
    Tanh,
    /// Sigmoid activation: 1 / (1 + exp(-x))
    Sigmoid,
    /// Error function (erf)
    Erf,
    /// Linear activation (identity)
    Linear,
    /// GELU activation
    GELU,
    /// Swish activation: x * sigmoid(x)
    Swish,
}

impl Activation {
    /// Apply activation function element-wise
    pub fn apply(&self, x: Float) -> Float {
        match self {
            Activation::ReLU => x.max(0.0),
            Activation::Tanh => x.tanh(),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Erf => {
                // Abramowitz and Stegun approximation for erf(x)
                let sign = if x >= 0.0 { 1.0 } else { -1.0 };
                let x_abs = x.abs();
                let t = 1.0 / (1.0 + 0.3275911 * x_abs);
                let approx = 1.0
                    - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736)
                        * t
                        + 0.254829592)
                        * t
                        * (-x_abs * x_abs).exp();
                sign * approx
            }
            Activation::Linear => x,
            Activation::GELU => {
                // GELU(x) = x * Φ(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
                let sqrt_2_over_pi = (2.0 / PI).sqrt();
                0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x.powi(3))).tanh())
            }
            Activation::Swish => {
                let sigmoid = 1.0 / (1.0 + (-x).exp());
                x * sigmoid
            }
        }
    }

    /// Compute the expected activation covariance (the "activation kernel")
    /// at correlation `rho`; used in the NTK/NNGP layer recursions
    pub fn kernel_value(&self, rho: Float) -> Float {
        match self {
            Activation::ReLU => {
                // Arc-cosine kernel of degree 1:
                // K(x, x') = ||x|| ||x'|| / (2π) * (sin(θ) + (π - θ) cos(θ)),
                // where cos(θ) = rho; this returns the correlation part only
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI)
            }
            Activation::Tanh => {
                // Arcsine-style approximation for tanh; the asin argument is
                // clamped to [-1, 1] so the kernel stays finite as rho -> ±1
                2.0 / PI * (rho * (1.0 + rho.powi(2)).sqrt()).clamp(-1.0, 1.0).asin()
            }
            Activation::Erf => {
                // Arcsine kernel approximation for the erf activation
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).sqrt())).asin()
            }
            Activation::Linear => rho,
            Activation::Sigmoid => {
                // Approximation for the sigmoid kernel; abs() guards the
                // square root against negative arguments when |rho| > 1
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).abs().sqrt())).asin()
            }
            Activation::GELU => {
                // GELU kernel approximation: arc-cosine kernel scaled by 1.702,
                // mirroring the sigmoid-based approximation GELU(x) ≈ x·σ(1.702x)
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.702
            }
            Activation::Swish => {
                // Swish kernel approximation (similar in shape to GELU)
                let theta = rho.clamp(-1.0, 1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.5
            }
        }
    }
}
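
// A minimal numerical sketch (added for illustration, not part of the original
// API surface): with θ = acos(ρ), the ReLU arc-cosine kernel above gives 1/2 at
// ρ = 1 (θ = 0) and 1/(2π) at ρ = 0 (θ = π/2); the linear "kernel" is ρ itself.
#[cfg(test)]
mod activation_kernel_sketch {
    use super::*;

    #[test]
    fn relu_arc_cosine_reference_values() {
        assert!((Activation::ReLU.kernel_value(1.0) - 0.5).abs() < 1e-9);
        assert!((Activation::ReLU.kernel_value(0.0) - 1.0 / (2.0 * PI)).abs() < 1e-9);
        // The linear activation leaves the correlation unchanged.
        assert!((Activation::Linear.kernel_value(0.3) - 0.3).abs() < 1e-12);
    }
}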

/// Configuration for Neural Tangent Kernel
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NTKConfig {
    /// Number of layers in the neural network
    pub n_layers: usize,
    /// Width of hidden layers (for finite-width approximation)
    pub hidden_width: Option<usize>,
    /// Activation function
    pub activation: Activation,
    /// Whether to use the infinite-width limit
    pub infinite_width: bool,
    /// Variance of weight initialization
    pub weight_variance: Float,
    /// Variance of bias initialization
    pub bias_variance: Float,
}

impl Default for NTKConfig {
    fn default() -> Self {
        Self {
            n_layers: 3,
            hidden_width: Some(1024),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 1.0,
        }
    }
}
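
// One step of the layer recursion used by `compute_ntk_kernel` below, written
// out by hand for a pair of unit-norm inputs with correlation ρ = 0 under the
// default config (σ_w² = σ_b² = 1, ReLU): K¹ = σ_w² · f(ρ) + σ_b², with f the
// activation kernel. A hedged sketch for illustration; it assumes only the
// items defined in this module.
#[cfg(test)]
mod ntk_recursion_sketch {
    use super::*;

    #[test]
    fn one_relu_layer_by_hand() {
        let config = NTKConfig::default();
        let rho: Float = 0.0; // orthogonal unit-norm inputs
        let k1 =
            config.weight_variance * Activation::ReLU.kernel_value(rho) + config.bias_variance;
        // f(0) = 1/(2π) for the ReLU arc-cosine kernel, so K¹ = 1/(2π) + 1.
        assert!((k1 - (1.0 / (2.0 * PI) + 1.0)).abs() < 1e-12);
    }
}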

/// Neural Tangent Kernel (NTK) Approximation
///
/// The Neural Tangent Kernel describes the evolution of an infinitely-wide neural network
/// during gradient descent. This implementation provides both exact infinite-width
/// computation and finite-width approximations.
///
/// # Mathematical Background
///
/// For a fully-connected neural network with L layers and activation σ:
/// - The NTK is defined as: Θ(x, x') = E[∂f/∂θ(x) · ∂f/∂θ(x')]
/// - In the infinite-width limit, the NTK remains constant during training
/// - Under gradient flow, the network output evolves as: df/dt = -Θ(X, X) ∇_f L(f)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{NeuralTangentKernel, NTKConfig, Activation};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = NTKConfig {
///     n_layers: 3,
///     activation: Activation::ReLU,
///     infinite_width: true,
///     ..Default::default()
/// };
///
/// let ntk = NeuralTangentKernel::new(config);
/// let X = array![[1.0, 2.0], [3.0, 4.0]];
/// let fitted = ntk.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape()[0], 2);
/// ```
#[derive(Debug, Clone)]
pub struct NeuralTangentKernel<State = Untrained> {
    config: NTKConfig,
    n_components: usize,

    // Fitted attributes
    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl NeuralTangentKernel<Untrained> {
    /// Create a new Neural Tangent Kernel with the given configuration
    pub fn new(config: NTKConfig) -> Self {
        Self {
            config,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Create a new NTK with default configuration
    pub fn with_layers(n_layers: usize) -> Self {
        Self {
            config: NTKConfig {
                n_layers,
                ..Default::default()
            },
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Set activation function
    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    /// Set whether to use infinite-width limit
    pub fn infinite_width(mut self, infinite: bool) -> Self {
        self.config.infinite_width = infinite;
        self
    }

    /// Set number of components
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Compute the NTK-style kernel matrix between two sets of points
    ///
    /// This follows the layer-wise covariance recursion
    /// K^{l+1}(x, x') = σ_w² · E[σ(u) σ(v)] + σ_b², with the expectation over
    /// (u, v) Gaussian with covariance K^l. The self-covariances K^l(x, x) and
    /// K^l(x', x') are tracked separately so the cross-kernel matrix does not
    /// need to be square.
    fn compute_ntk_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_samples_x = x.nrows();
        let n_samples_y = y.nrows();

        // Initialize kernel matrix with the dot product kernel (layer 0),
        // normalized by input dimension for the NTK parameterization
        let d = x.ncols() as Float;
        let mut kernel = x.dot(&y.t());
        kernel.mapv_inplace(|k| k / d);

        // Per-point self-covariances, propagated through the same recursion
        let mut diag_x = Array1::from_shape_fn(n_samples_x, |i| x.row(i).dot(&x.row(i)) / d);
        let mut diag_y = Array1::from_shape_fn(n_samples_y, |j| y.row(j).dot(&y.row(j)) / d);

        // Recursively compute kernels through the layers
        for _layer in 0..self.config.n_layers {
            let mut new_kernel = Array2::zeros((n_samples_x, n_samples_y));

            for i in 0..n_samples_x {
                for j in 0..n_samples_y {
                    let k_ij = kernel[[i, j]];

                    // Correlation between the two points under the current kernel
                    let norm = (diag_x[i] * diag_y[j]).sqrt().max(1e-10);
                    let rho = (k_ij / norm).clamp(-1.0, 1.0);

                    // Apply the activation kernel, scaled by the variances
                    let activated = self.config.activation.kernel_value(rho);
                    new_kernel[[i, j]] =
                        self.config.weight_variance * norm * activated + self.config.bias_variance;
                }
            }

            // Propagate the self-covariances (rho = 1 on the diagonal)
            let diag_activated = self.config.activation.kernel_value(1.0);
            diag_x.mapv_inplace(|k| {
                self.config.weight_variance * k * diag_activated + self.config.bias_variance
            });
            diag_y.mapv_inplace(|k| {
                self.config.weight_variance * k * diag_activated + self.config.bias_variance
            });

            kernel = new_kernel;
        }

        Ok(kernel)
    }

    /// Compute top k eigenvectors using power iteration with deflation
    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");

        for i in 0..k {
            // Random initialization
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal));

            // Power iteration
            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            // Store eigenvector
            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            // Deflate: subtract λ v vᵀ so the next iteration converges to the
            // next-largest eigenpair
            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}
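
// Illustrative check (not in the original suite) that the power-iteration
// helper recovers the dominant eigenvector of a small symmetric matrix with a
// known spectrum: [[2, 1], [1, 2]] has eigenpairs (3, (1, 1)) and (1, (1, -1)).
#[cfg(test)]
mod power_iteration_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn recovers_dominant_eigenvector() {
        let ntk = NeuralTangentKernel::with_layers(1);
        let m = array![[2.0, 1.0], [1.0, 2.0]];
        let vecs = ntk.compute_top_eigenvectors(&m, 1).unwrap();
        // The dominant eigenvector is (1, 1)/√2 up to sign, so the component
        // ratio should be 1 in absolute value.
        let ratio = (vecs[[0, 0]] / vecs[[1, 0]]).abs();
        assert!((ratio - 1.0).abs() < 1e-6);
    }
}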

impl Estimator for NeuralTangentKernel<Untrained> {
    type Config = NTKConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for NeuralTangentKernel<Untrained> {
    type Fitted = NeuralTangentKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        // Store training data
        let x_train = x.clone();

        // For dimensionality reduction, compute eigendecomposition of kernel matrix
        let kernel = self.compute_ntk_kernel(x, x)?;

        // Use eigendecomposition for feature extraction
        let n_components = x.nrows().min(self.n_components);

        // Simple eigendecomposition using power iteration for top eigenvectors
        let eigenvectors = if n_components < x.nrows() {
            Some(self.compute_top_eigenvectors(&kernel, n_components)?)
        } else {
            None
        };

        Ok(NeuralTangentKernel {
            config: self.config,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors,
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for NeuralTangentKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self
            .x_train
            .as_ref()
            .expect("trained model always stores training data");

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        // Compute kernel between x and the training data, reusing the untrained
        // kernel computation with the same configuration
        let ntk = NeuralTangentKernel::<Untrained> {
            config: self.config.clone(),
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };
        let kernel = ntk.compute_ntk_kernel(x, x_train)?;

        // Project onto eigenvectors if available
        if let Some(ref eigvecs) = self.eigenvectors {
            Ok(kernel.dot(eigvecs))
        } else {
            Ok(kernel)
        }
    }
}

impl NeuralTangentKernel<Trained> {
    /// Get the training data
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train
            .as_ref()
            .expect("trained model always stores training data")
    }

    /// Get the eigenvectors
    pub fn eigenvectors(&self) -> Option<&Array2<Float>> {
        self.eigenvectors.as_ref()
    }
}
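
// End-to-end sketch (illustrative) of the eigenvector projection path in
// `transform`: when n_components < n_samples, the train-kernel columns are
// projected onto the stored eigenvectors, so the output width equals
// n_components rather than the number of training samples.
#[cfg(test)]
mod ntk_projection_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn projects_onto_requested_components() {
        let ntk = NeuralTangentKernel::with_layers(2).n_components(2);
        let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]];
        let fitted = ntk.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();
        // 4 samples projected onto the top-2 eigenvectors of the train kernel.
        assert_eq!(features.shape(), &[4, 2]);
    }
}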

/// Deep Kernel Learning combines deep neural networks with kernel methods
///
/// This approach uses a deep neural network to learn a feature representation,
/// then applies a kernel method in the learned feature space.
///
/// # Mathematical Background
///
/// Deep Kernel Learning defines a composite kernel:
/// k(x, x') = k_base(φ(x; θ), φ(x'; θ))
/// where:
/// - φ(·; θ) is a deep neural network with parameters θ
/// - k_base is a base kernel (e.g., RBF)
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{DeepKernelLearning, DKLConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = DKLConfig {
///     feature_layers: vec![10, 20, 10],
///     n_components: 50,
///     ..Default::default()
/// };
///
/// let dkl = DeepKernelLearning::new(config);
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let fitted = dkl.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape(), &[3, 50]);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DKLConfig {
    /// Sizes of feature extraction layers
    pub feature_layers: Vec<usize>,
    /// Number of random Fourier features for final kernel
    pub n_components: usize,
    /// Activation function for feature layers
    pub activation: Activation,
    /// Bandwidth for final RBF kernel
    pub gamma: Float,
    /// Learning rate for feature learning (currently unused: the feature map
    /// is randomly initialized rather than trained)
    pub learning_rate: Float,
}

impl Default for DKLConfig {
    fn default() -> Self {
        Self {
            feature_layers: vec![64, 32],
            n_components: 100,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        }
    }
}

/// Deep Kernel Learning transformer
///
/// Applies the randomly initialized feature-extraction layers described by
/// [`DKLConfig`], then random Fourier features approximating an RBF kernel on
/// the extracted features.
#[derive(Debug, Clone)]
pub struct DeepKernelLearning<State = Untrained> {
    config: DKLConfig,

    // Fitted attributes
    layer_weights: Option<Vec<Array2<Float>>>,
    layer_biases: Option<Vec<Array1<Float>>>,
    random_weights: Option<Array2<Float>>,
    random_offset: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl DeepKernelLearning<Untrained> {
    /// Create new Deep Kernel Learning with configuration
    pub fn new(config: DKLConfig) -> Self {
        Self {
            config,
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    /// Create with default configuration
    pub fn with_components(n_components: usize) -> Self {
        Self {
            config: DKLConfig {
                n_components,
                ..Default::default()
            },
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    /// Set activation function
    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    /// Set gamma (bandwidth) for final RBF kernel
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.config.gamma = gamma;
        self
    }
}

impl Estimator for DeepKernelLearning<Untrained> {
    type Config = DKLConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for DeepKernelLearning<Untrained> {
    type Fitted = DeepKernelLearning<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");

        // Initialize feature extraction layers with random weights
        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();

        let mut in_features = x.ncols();
        for &out_features in &self.config.feature_layers {
            // Xavier/Glorot initialization: this scale keeps activation
            // variance roughly constant across layers
            let scale = (2.0 / (in_features + out_features) as Float).sqrt();

            let weights = Array2::from_shape_fn((in_features, out_features), |_| {
                rng.sample(normal_dist) * scale
            });

            let biases = Array1::from_shape_fn(out_features, |_| rng.sample(normal_dist) * 0.01);

            layer_weights.push(weights);
            layer_biases.push(biases);
            in_features = out_features;
        }

        // Initialize random Fourier features for the final RBF kernel
        let final_features = if self.config.feature_layers.is_empty() {
            x.ncols()
        } else {
            *self
                .config
                .feature_layers
                .last()
                .expect("feature_layers is non-empty in this branch")
        };

        let random_weights =
            Array2::from_shape_fn((final_features, self.config.n_components), |_| {
                rng.sample(normal_dist) * (2.0 * self.config.gamma).sqrt()
            });

        let uniform_dist = Uniform::new(0.0, 2.0 * PI).expect("valid uniform range");
        let random_offset =
            Array1::from_shape_fn(self.config.n_components, |_| rng.sample(uniform_dist));

        Ok(DeepKernelLearning {
            config: self.config,
            layer_weights: Some(layer_weights),
            layer_biases: Some(layer_biases),
            random_weights: Some(random_weights),
            random_offset: Some(random_offset),
            _state: PhantomData,
        })
    }
}
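
// Monte Carlo sketch of the Xavier/Glorot scaling used in `fit` above: with
// scale sqrt(2 / (fan_in + fan_out)), the empirical variance of a large random
// weight matrix should be close to 2 / (fan_in + fan_out). Loose tolerance;
// illustrative rather than part of the original suite.
#[cfg(test)]
mod xavier_init_sketch {
    use super::*;

    #[test]
    fn weight_variance_matches_glorot_scale() {
        let (fan_in, fan_out) = (200usize, 300usize);
        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");
        let scale = (2.0 / (fan_in + fan_out) as Float).sqrt();
        let w = Array2::from_shape_fn((fan_in, fan_out), |_| rng.sample(normal) * scale);
        // Empirical second moment over 60_000 draws; sampling error is ~0.6%.
        let var = w.iter().map(|v| v * v).sum::<Float>() / (fan_in * fan_out) as Float;
        let target = 2.0 / (fan_in + fan_out) as Float;
        assert!((var - target).abs() < 0.1 * target);
    }
}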

impl DeepKernelLearning<Trained> {
    /// Apply feature extraction layers
    fn extract_features(&self, x: &Array2<Float>) -> Array2<Float> {
        let mut features = x.clone();
        let layer_weights = self
            .layer_weights
            .as_ref()
            .expect("trained model always stores layer weights");
        let layer_biases = self
            .layer_biases
            .as_ref()
            .expect("trained model always stores layer biases");

        for (weights, biases) in layer_weights.iter().zip(layer_biases.iter()) {
            // Linear transformation
            features = features.dot(weights);

            // Add bias
            for i in 0..features.nrows() {
                for j in 0..features.ncols() {
                    features[[i, j]] += biases[j];
                }
            }

            // Apply activation
            features.mapv_inplace(|v| self.config.activation.apply(v));
        }

        features
    }
}

impl Transform<Array2<Float>, Array2<Float>> for DeepKernelLearning<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        // Extract deep features
        let deep_features = self.extract_features(x);

        // Apply random Fourier features to the deep features
        let random_weights = self
            .random_weights
            .as_ref()
            .expect("trained model always stores random weights");
        let random_offset = self
            .random_offset
            .as_ref()
            .expect("trained model always stores the random offset");

        let projection = deep_features.dot(random_weights);

        let n_samples = x.nrows();
        let mut output = Array2::zeros((n_samples, self.config.n_components));

        // Standard RFF feature map: z_j(u) = sqrt(2/D) cos(w_j · u + b_j)
        let normalizer = (2.0 / self.config.n_components as Float).sqrt();
        for i in 0..n_samples {
            for j in 0..self.config.n_components {
                output[[i, j]] = normalizer * (projection[[i, j]] + random_offset[j]).cos();
            }
        }

        Ok(output)
    }
}

impl DeepKernelLearning<Trained> {
    /// Get the layer weights
    pub fn layer_weights(&self) -> &Vec<Array2<Float>> {
        self.layer_weights
            .as_ref()
            .expect("trained model always stores layer weights")
    }

    /// Get the layer biases
    pub fn layer_biases(&self) -> &Vec<Array1<Float>> {
        self.layer_biases
            .as_ref()
            .expect("trained model always stores layer biases")
    }

    /// Get the random weights
    pub fn random_weights(&self) -> &Array2<Float> {
        self.random_weights
            .as_ref()
            .expect("trained model always stores random weights")
    }

    /// Get the random offset
    pub fn random_offset(&self) -> &Array1<Float> {
        self.random_offset
            .as_ref()
            .expect("trained model always stores the random offset")
    }
}
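
// Sketch of the random-Fourier-feature identity behind `transform`: for
// z(u) = sqrt(2/D) cos(W u + b) with entries of W drawn N(0, 2γ) and
// b ~ U[0, 2π), E[z(u) · z(v)] = exp(-γ ||u - v||²) (Rahimi & Recht, 2007).
// With empty feature_layers the feature map is the identity, so the dot
// product of two transformed rows should track the RBF kernel up to Monte
// Carlo error. Loose tolerance; illustrative only.
#[cfg(test)]
mod dkl_rff_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn rff_inner_product_tracks_rbf_kernel() {
        let config = DKLConfig {
            feature_layers: vec![], // identity feature map: pure RFF on the inputs
            n_components: 2000,
            gamma: 0.5,
            ..Default::default()
        };
        let dkl = DeepKernelLearning::new(config);
        let x = array![[0.0, 0.0], [1.0, 0.0]];
        let fitted = dkl.fit(&x, &()).unwrap();
        let z = fitted.transform(&x).unwrap();

        let dot: Float = (0..z.ncols()).map(|j| z[[0, j]] * z[[1, j]]).sum();
        let expected = (-0.5 as Float).exp(); // exp(-γ ||x0 - x1||²) with γ = 0.5
        assert!((dot - expected).abs() < 0.1);
    }
}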

/// Infinite-Width Network Kernel
///
/// This implements the kernel corresponding to an infinitely-wide neural network,
/// also known as the Neural Network Gaussian Process (NNGP).
///
/// # Mathematical Background
///
/// For a neural network f(x; θ) with random weights θ ~ N(0, σ²/n_h):
/// As the hidden layer width n_h → ∞, f(x; θ) → GP(0, K)
/// where K is the NNGP kernel determined by the architecture and activation.
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::deep_learning_kernels::{InfiniteWidthKernel, Activation};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
/// let X = array![[1.0, 2.0], [3.0, 4.0]];
/// let fitted = kernel.fit(&X, &()).unwrap();
/// let features = fitted.transform(&X).unwrap();
/// assert_eq!(features.shape()[0], 2);
/// ```
#[derive(Debug, Clone)]
pub struct InfiniteWidthKernel<State = Untrained> {
    n_layers: usize,
    activation: Activation,
    n_components: usize,

    // Fitted attributes
    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl InfiniteWidthKernel<Untrained> {
    /// Create new infinite-width kernel
    pub fn new(n_layers: usize, activation: Activation) -> Self {
        Self {
            n_layers,
            activation,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    /// Set number of components for dimensionality reduction
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Compute the NNGP kernel matrix
    ///
    /// The self-covariances K(x_i, x_i) and K(y_j, y_j) are tracked separately
    /// so the cross-kernel matrix does not need to be square.
    fn compute_nngp_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
        let n_x = x.nrows();
        let n_y = y.nrows();
        let d = x.ncols() as Float;

        // Initialize with the normalized dot product
        let mut kernel = x.dot(&y.t());
        kernel.mapv_inplace(|k| k / d);

        // Per-point self-covariances, propagated through the same recursion
        let mut diag_x = Array1::from_shape_fn(n_x, |i| x.row(i).dot(&x.row(i)) / d);
        let mut diag_y = Array1::from_shape_fn(n_y, |j| y.row(j).dot(&y.row(j)) / d);

        // Recursively apply activation kernels
        for _ in 0..self.n_layers {
            let mut new_kernel = Array2::zeros((n_x, n_y));

            for i in 0..n_x {
                for j in 0..n_y {
                    let k_ij = kernel[[i, j]];
                    let norm = (diag_x[i] * diag_y[j]).sqrt().max(1e-10);
                    let rho = (k_ij / norm).clamp(-1.0, 1.0);

                    new_kernel[[i, j]] = norm * self.activation.kernel_value(rho);
                }
            }

            // Propagate the self-covariances (rho = 1 on the diagonal)
            let diag_activated = self.activation.kernel_value(1.0);
            diag_x.mapv_inplace(|k| k * diag_activated);
            diag_y.mapv_inplace(|k| k * diag_activated);

            kernel = new_kernel;
        }

        kernel
    }

    /// Compute top k eigenvectors using power iteration with deflation
    /// (mirrors the helper on `NeuralTangentKernel`)
    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");

        for i in 0..k {
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal_dist));

            // Power iteration
            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}

impl Estimator for InfiniteWidthKernel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for InfiniteWidthKernel<Untrained> {
    type Fitted = InfiniteWidthKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let x_train = x.clone();
        let kernel = self.compute_nngp_kernel(x, x);

        // Compute eigenvectors using power iteration
        let n_components = self.n_components.min(x.nrows());
        let eigenvectors = self.compute_top_eigenvectors(&kernel, n_components)?;

        Ok(InfiniteWidthKernel {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors: Some(eigenvectors),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for InfiniteWidthKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self
            .x_train
            .as_ref()
            .expect("trained model always stores training data");
        let eigenvectors = self
            .eigenvectors
            .as_ref()
            .expect("trained model always stores eigenvectors");

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        let kernel_obj = InfiniteWidthKernel::<Untrained> {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };

        let kernel = kernel_obj.compute_nngp_kernel(x, x_train);
        Ok(kernel.dot(eigenvectors))
    }
}

impl InfiniteWidthKernel<Trained> {
    /// Get the training data
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train
            .as_ref()
            .expect("trained model always stores training data")
    }

    /// Get the eigenvectors
    pub fn eigenvectors(&self) -> &Array2<Float> {
        self.eigenvectors
            .as_ref()
            .expect("trained model always stores eigenvectors")
    }
}
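
// Sketch of a basic NNGP kernel property: the recursion depends on the inputs
// only through their pairwise and self covariances, so duplicated inputs must
// produce identical kernel rows. Illustrative check against the private
// `compute_nngp_kernel` helper above.
#[cfg(test)]
mod nngp_kernel_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn duplicate_inputs_give_identical_rows() {
        let kernel = InfiniteWidthKernel::new(2, Activation::ReLU);
        let x = array![[1.0, 2.0], [1.0, 2.0], [3.0, -1.0]];
        let gram = kernel.compute_nngp_kernel(&x, &x);
        for j in 0..3 {
            assert!((gram[[0, j]] - gram[[1, j]]).abs() < 1e-12);
        }
    }
}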

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_activation_functions() {
        let activations = vec![
            Activation::ReLU,
            Activation::Tanh,
            Activation::Sigmoid,
            Activation::Linear,
            Activation::GELU,
            Activation::Swish,
            Activation::Erf,
        ];

        for act in activations {
            let val = act.apply(0.5);
            assert!(val.is_finite());

            let kernel_val = act.kernel_value(0.5);
            assert!(kernel_val.is_finite());
        }
    }

    #[test]
    fn test_neural_tangent_kernel_basic() {
        let config = NTKConfig {
            n_layers: 2,
            hidden_width: Some(512),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 0.1,
        };

        let ntk = NeuralTangentKernel::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = ntk.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_deep_kernel_learning() {
        let config = DKLConfig {
            feature_layers: vec![10, 20],
            n_components: 50,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.shape(), &[3, 50]);
    }

    #[test]
    fn test_infinite_width_kernel() {
        let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let fitted = kernel.fit(&x, &()).expect("fit should succeed");
        let features = fitted.transform(&x).expect("transform should succeed");

        assert_eq!(features.nrows(), 4);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_ntk_different_activations() {
        let activations = vec![Activation::ReLU, Activation::Tanh, Activation::GELU];
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        for act in activations {
            let ntk = NeuralTangentKernel::with_layers(2).activation(act);
            let fitted = ntk.fit(&x, &()).expect("fit should succeed");
            let features = fitted.transform(&x).expect("transform should succeed");

            assert_eq!(features.nrows(), 2);
        }
    }

    #[test]
    fn test_dkl_feature_extraction() {
        let config = DKLConfig {
            feature_layers: vec![8, 4],
            n_components: 20,
            activation: Activation::Tanh,
            gamma: 0.5,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).expect("fit should succeed");

        // Test that features have correct shape
        let features = fitted.transform(&x).expect("transform should succeed");
        assert_eq!(features.shape(), &[2, 20]);

        // Test that all features are finite
        for val in features.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_empty_input_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_empty: Array2<Float> = Array2::zeros((0, 0));

        assert!(ntk.fit(&x_empty, &()).is_err());
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];

        let fitted = ntk.fit(&x_train, &()).expect("fit should succeed");
        assert!(fitted.transform(&x_test).is_err());
    }
}
1037}