1use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::essentials::{Normal, Uniform};
15use scirs2_core::random::thread_rng;
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18 error::{Result, SklearsError},
19 prelude::{Fit, Transform},
20 traits::{Estimator, Trained, Untrained},
21 types::Float,
22};
23use std::marker::PhantomData;
24
/// π at the crate's `Float` precision, shared by the kernel formulas below.
const PI: Float = std::f64::consts::PI;
26
/// Activation functions available to the kernel constructions in this module.
///
/// Each variant has both a pointwise form ([`Activation::apply`]) and a
/// dual "kernel" form ([`Activation::kernel_value`]) used in the
/// layer-recursive kernel computations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Activation {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Gauss error function (polynomial approximation).
    Erf,
    /// Identity activation.
    Linear,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// Swish / SiLU: `x * sigmoid(x)`.
    Swish,
}
45
46impl Activation {
47 pub fn apply(&self, x: Float) -> Float {
49 match self {
50 Activation::ReLU => x.max(0.0),
51 Activation::Tanh => x.tanh(),
52 Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
53 Activation::Erf => {
54 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
56 let x_abs = x.abs();
57 let t = 1.0 / (1.0 + 0.3275911 * x_abs);
58 let approx = 1.0
59 - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736)
60 * t
61 + 0.254829592)
62 * t
63 * (-x_abs * x_abs).exp();
64 sign * approx
65 }
66 Activation::Linear => x,
67 Activation::GELU => {
68 let sqrt_2_over_pi = (2.0 / PI).sqrt();
70 0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x.powi(3))).tanh())
71 }
72 Activation::Swish => {
73 let sigmoid = 1.0 / (1.0 + (-x).exp());
74 x * sigmoid
75 }
76 }
77 }
78
79 pub fn kernel_value(&self, rho: Float) -> Float {
82 match self {
83 Activation::ReLU => {
84 let theta = rho.max(-1.0).min(1.0).acos();
87 (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI)
88 }
89 Activation::Tanh => {
90 2.0 / PI * (rho * (1.0 + rho.powi(2)).sqrt()).asin()
92 }
93 Activation::Erf => {
94 2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).sqrt())).asin()
96 }
97 Activation::Linear => rho,
98 Activation::Sigmoid => {
99 2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2).abs()).sqrt())).asin()
101 }
102 Activation::GELU => {
103 let theta = rho.max(-1.0).min(1.0).acos();
105 (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.702
106 }
107 Activation::Swish => {
108 let theta = rho.max(-1.0).min(1.0).acos();
110 (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.5
111 }
112 }
113 }
114}
115
/// Configuration for [`NeuralTangentKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NTKConfig {
    /// Number of layer recursions applied when computing the kernel.
    pub n_layers: usize,
    /// Hidden width for the finite-width regime; unused by the kernel
    /// recursion itself in this module.
    pub hidden_width: Option<usize>,
    /// Activation whose dual kernel form drives the recursion.
    pub activation: Activation,
    /// Whether to model the infinite-width limit.
    pub infinite_width: bool,
    /// Variance scale applied to the weight term at each layer.
    pub weight_variance: Float,
    /// Additive bias variance at each layer.
    pub bias_variance: Float,
}
132
133impl Default for NTKConfig {
134 fn default() -> Self {
135 Self {
136 n_layers: 3,
137 hidden_width: Some(1024),
138 activation: Activation::ReLU,
139 infinite_width: true,
140 weight_variance: 1.0,
141 bias_variance: 1.0,
142 }
143 }
144}
145
/// Feature transformer based on a layer-recursive NTK-style kernel
/// against stored training data, with an optional spectral projection.
/// The `State` parameter encodes the trained/untrained typestate.
#[derive(Debug, Clone)]
pub struct NeuralTangentKernel<State = Untrained> {
    /// Kernel configuration (depth, activation, variances).
    config: NTKConfig,
    /// Number of eigenvectors retained for the spectral projection.
    n_components: usize,

    /// Training data captured at fit time; `None` while untrained.
    x_train: Option<Array2<Float>>,
    /// Top eigenvectors of the training Gram matrix; `None` when no
    /// reduction was needed (n_components >= n_samples).
    eigenvectors: Option<Array2<Float>>,

    /// Zero-sized marker carrying the typestate.
    _state: PhantomData<State>,
}
190
191impl NeuralTangentKernel<Untrained> {
192 pub fn new(config: NTKConfig) -> Self {
194 Self {
195 config,
196 n_components: 100,
197 x_train: None,
198 eigenvectors: None,
199 _state: PhantomData,
200 }
201 }
202
203 pub fn with_layers(n_layers: usize) -> Self {
205 Self {
206 config: NTKConfig {
207 n_layers,
208 ..Default::default()
209 },
210 n_components: 100,
211 x_train: None,
212 eigenvectors: None,
213 _state: PhantomData,
214 }
215 }
216
217 pub fn activation(mut self, activation: Activation) -> Self {
219 self.config.activation = activation;
220 self
221 }
222
223 pub fn infinite_width(mut self, infinite: bool) -> Self {
225 self.config.infinite_width = infinite;
226 self
227 }
228
229 pub fn n_components(mut self, n: usize) -> Self {
231 self.n_components = n;
232 self
233 }
234
235 fn compute_ntk_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
237 let n_samples_x = x.nrows();
238 let n_samples_y = y.nrows();
239
240 let mut kernel = x.dot(&y.t());
242
243 let d = x.ncols() as Float;
245 kernel.mapv_inplace(|k| k / d);
246
247 for _layer in 0..self.config.n_layers {
249 let mut new_kernel = Array2::zeros((n_samples_x, n_samples_y));
250
251 for i in 0..n_samples_x {
252 for j in 0..n_samples_y {
253 let k_ij = kernel[[i, j]];
254 let k_ii = if i < kernel.nrows() && i < kernel.ncols() {
255 kernel[[i, i]]
256 } else {
257 1.0
258 };
259 let k_jj = if j < kernel.nrows() && j < kernel.ncols() {
260 kernel[[j, j]]
261 } else {
262 1.0
263 };
264
265 let norm = (k_ii * k_jj).sqrt().max(1e-10);
267 let rho = (k_ij / norm).max(-1.0).min(1.0);
268
269 let activated = self.config.activation.kernel_value(rho);
271
272 new_kernel[[i, j]] =
274 self.config.weight_variance * norm * activated + self.config.bias_variance;
275 }
276 }
277
278 kernel = new_kernel;
279 }
280
281 Ok(kernel)
282 }
283
284 fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
286 let n = kernel.nrows();
287 let mut eigenvectors = Array2::zeros((n, k));
288 let mut kernel_deflated = kernel.clone();
289
290 let mut rng = thread_rng();
291 let normal = Normal::new(0.0, 1.0).unwrap();
292
293 for i in 0..k {
294 let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal));
296
297 for _iter in 0..50 {
299 v = kernel_deflated.dot(&v);
300 let norm = v.dot(&v).sqrt();
301 if norm > 1e-10 {
302 v /= norm;
303 } else {
304 break;
305 }
306 }
307
308 for j in 0..n {
310 eigenvectors[[j, i]] = v[j];
311 }
312
313 let lambda = v.dot(&kernel_deflated.dot(&v));
315 for row in 0..n {
316 for col in 0..n {
317 kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
318 }
319 }
320 }
321
322 Ok(eigenvectors)
323 }
324}
325
// Estimator plumbing: exposes the configuration and error/float types.
impl Estimator for NeuralTangentKernel<Untrained> {
    type Config = NTKConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
335
336impl Fit<Array2<Float>, ()> for NeuralTangentKernel<Untrained> {
337 type Fitted = NeuralTangentKernel<Trained>;
338
339 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
340 if x.nrows() == 0 || x.ncols() == 0 {
341 return Err(SklearsError::InvalidInput(
342 "Input array cannot be empty".to_string(),
343 ));
344 }
345
346 let x_train = x.clone();
348
349 let kernel = self.compute_ntk_kernel(x, x)?;
351
352 let n_components = x.nrows().min(self.n_components);
354
355 let eigenvectors = if n_components < x.nrows() {
357 Some(self.compute_top_eigenvectors(&kernel, n_components)?)
358 } else {
359 None
360 };
361
362 Ok(NeuralTangentKernel {
363 config: self.config,
364 n_components: self.n_components,
365 x_train: Some(x_train),
366 eigenvectors,
367 _state: PhantomData,
368 })
369 }
370}
371
372impl Transform<Array2<Float>, Array2<Float>> for NeuralTangentKernel<Trained> {
373 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
374 let x_train = self.x_train.as_ref().unwrap();
375
376 if x.ncols() != x_train.ncols() {
377 return Err(SklearsError::InvalidInput(format!(
378 "Feature dimension mismatch: expected {}, got {}",
379 x_train.ncols(),
380 x.ncols()
381 )));
382 }
383
384 let ntk = NeuralTangentKernel::<Untrained> {
386 config: self.config.clone(),
387 n_components: self.n_components,
388 x_train: None,
389 eigenvectors: None,
390 _state: PhantomData,
391 };
392 let kernel = ntk.compute_ntk_kernel(x, x_train)?;
393
394 if let Some(ref eigvecs) = self.eigenvectors {
396 Ok(kernel.dot(eigvecs))
397 } else {
398 Ok(kernel)
399 }
400 }
401}
402
impl NeuralTangentKernel<Trained> {
    /// Training data captured at fit time.
    ///
    /// # Panics
    /// Never in practice: the `Trained` typestate guarantees `x_train`
    /// was populated by `fit`.
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().unwrap()
    }

    /// Top eigenvectors of the training Gram matrix, if a reduced
    /// projection was computed during `fit`.
    pub fn eigenvectors(&self) -> Option<&Array2<Float>> {
        self.eigenvectors.as_ref()
    }
}
414
/// Configuration for [`DeepKernelLearning`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DKLConfig {
    /// Output widths of the randomly-initialized dense layers.
    pub feature_layers: Vec<usize>,
    /// Number of random Fourier features in the final embedding.
    pub n_components: usize,
    /// Activation applied after each dense layer.
    pub activation: Activation,
    /// Bandwidth parameter of the random-Fourier-feature map.
    pub gamma: Float,
    /// Stored learning rate; not used by the fitting code in this module.
    pub learning_rate: Float,
}
460
461impl Default for DKLConfig {
462 fn default() -> Self {
463 Self {
464 feature_layers: vec![64, 32],
465 n_components: 100,
466 activation: Activation::ReLU,
467 gamma: 1.0,
468 learning_rate: 0.01,
469 }
470 }
471}
472
/// Feature transformer combining randomly-initialized dense layers with a
/// random-Fourier-feature projection. The `State` parameter encodes the
/// trained/untrained typestate.
#[derive(Debug, Clone)]
pub struct DeepKernelLearning<State = Untrained> {
    /// Layer sizes, activation, bandwidth, and component count.
    config: DKLConfig,

    /// Per-layer weight matrices; `None` while untrained.
    layer_weights: Option<Vec<Array2<Float>>>,
    /// Per-layer bias vectors; `None` while untrained.
    layer_biases: Option<Vec<Array1<Float>>>,
    /// Random projection matrix for the Fourier feature map.
    random_weights: Option<Array2<Float>>,
    /// Random phase offsets for the Fourier feature map.
    random_offset: Option<Array1<Float>>,

    /// Zero-sized marker carrying the typestate.
    _state: PhantomData<State>,
}
485
486impl DeepKernelLearning<Untrained> {
487 pub fn new(config: DKLConfig) -> Self {
489 Self {
490 config,
491 layer_weights: None,
492 layer_biases: None,
493 random_weights: None,
494 random_offset: None,
495 _state: PhantomData,
496 }
497 }
498
499 pub fn with_components(n_components: usize) -> Self {
501 Self {
502 config: DKLConfig {
503 n_components,
504 ..Default::default()
505 },
506 layer_weights: None,
507 layer_biases: None,
508 random_weights: None,
509 random_offset: None,
510 _state: PhantomData,
511 }
512 }
513
514 pub fn activation(mut self, activation: Activation) -> Self {
516 self.config.activation = activation;
517 self
518 }
519
520 pub fn gamma(mut self, gamma: Float) -> Self {
522 self.config.gamma = gamma;
523 self
524 }
525}
526
// Estimator plumbing: exposes the configuration and error/float types.
impl Estimator for DeepKernelLearning<Untrained> {
    type Config = DKLConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
536
537impl Fit<Array2<Float>, ()> for DeepKernelLearning<Untrained> {
538 type Fitted = DeepKernelLearning<Trained>;
539
540 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
541 if x.nrows() == 0 || x.ncols() == 0 {
542 return Err(SklearsError::InvalidInput(
543 "Input array cannot be empty".to_string(),
544 ));
545 }
546
547 let mut rng = thread_rng();
548 let normal_dist = Normal::new(0.0, 1.0).unwrap();
549
550 let mut layer_weights = Vec::new();
552 let mut layer_biases = Vec::new();
553
554 let mut in_features = x.ncols();
555 for &out_features in &self.config.feature_layers {
556 let scale = (2.0 / (in_features + out_features) as Float).sqrt();
558
559 let weights = Array2::from_shape_fn((in_features, out_features), |_| {
560 rng.sample(normal_dist) * scale
561 });
562
563 let biases = Array1::from_shape_fn(out_features, |_| rng.sample(normal_dist) * 0.01);
564
565 layer_weights.push(weights);
566 layer_biases.push(biases);
567 in_features = out_features;
568 }
569
570 let final_features = if self.config.feature_layers.is_empty() {
572 x.ncols()
573 } else {
574 *self.config.feature_layers.last().unwrap()
575 };
576
577 let random_weights =
578 Array2::from_shape_fn((final_features, self.config.n_components), |_| {
579 rng.sample(normal_dist) * (2.0 * self.config.gamma).sqrt()
580 });
581
582 let uniform_dist = Uniform::new(0.0, 2.0 * PI).unwrap();
583 let random_offset =
584 Array1::from_shape_fn(self.config.n_components, |_| rng.sample(uniform_dist));
585
586 Ok(DeepKernelLearning {
587 config: self.config,
588 layer_weights: Some(layer_weights),
589 layer_biases: Some(layer_biases),
590 random_weights: Some(random_weights),
591 random_offset: Some(random_offset),
592 _state: PhantomData,
593 })
594 }
595}
596
597impl DeepKernelLearning<Trained> {
598 fn extract_features(&self, x: &Array2<Float>) -> Array2<Float> {
600 let mut features = x.clone();
601 let layer_weights = self.layer_weights.as_ref().unwrap();
602 let layer_biases = self.layer_biases.as_ref().unwrap();
603
604 for (weights, biases) in layer_weights.iter().zip(layer_biases.iter()) {
605 features = features.dot(weights);
607
608 for i in 0..features.nrows() {
610 for j in 0..features.ncols() {
611 features[[i, j]] += biases[j];
612 }
613 }
614
615 features.mapv_inplace(|v| self.config.activation.apply(v));
617 }
618
619 features
620 }
621}
622
623impl Transform<Array2<Float>, Array2<Float>> for DeepKernelLearning<Trained> {
624 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
625 let deep_features = self.extract_features(x);
627
628 let random_weights = self.random_weights.as_ref().unwrap();
630 let random_offset = self.random_offset.as_ref().unwrap();
631
632 let projection = deep_features.dot(random_weights);
633
634 let n_samples = x.nrows();
635 let mut output = Array2::zeros((n_samples, self.config.n_components));
636
637 let normalizer = (2.0 / self.config.n_components as Float).sqrt();
638 for i in 0..n_samples {
639 for j in 0..self.config.n_components {
640 output[[i, j]] = normalizer * (projection[[i, j]] + random_offset[j]).cos();
641 }
642 }
643
644 Ok(output)
645 }
646}
647
impl DeepKernelLearning<Trained> {
    /// Per-layer weight matrices of the frozen feature extractor.
    ///
    /// # Panics
    /// Never in practice: the `Trained` typestate guarantees the field
    /// was populated by `fit` (same for the accessors below).
    pub fn layer_weights(&self) -> &Vec<Array2<Float>> {
        self.layer_weights.as_ref().unwrap()
    }

    /// Per-layer bias vectors of the frozen feature extractor.
    pub fn layer_biases(&self) -> &Vec<Array1<Float>> {
        self.layer_biases.as_ref().unwrap()
    }

    /// Random projection matrix of the Fourier feature map.
    pub fn random_weights(&self) -> &Array2<Float> {
        self.random_weights.as_ref().unwrap()
    }

    /// Random phase offsets of the Fourier feature map.
    pub fn random_offset(&self) -> &Array1<Float> {
        self.random_offset.as_ref().unwrap()
    }
}
669
/// Feature transformer based on a layer-recursive NNGP-style kernel
/// against stored training data, projected onto the top eigenvectors of
/// the training Gram matrix. `State` encodes the typestate.
#[derive(Debug, Clone)]
pub struct InfiniteWidthKernel<State = Untrained> {
    /// Number of layer recursions applied when computing the kernel.
    n_layers: usize,
    /// Activation whose dual kernel form drives the recursion.
    activation: Activation,
    /// Number of eigenvectors retained for the spectral projection.
    n_components: usize,

    /// Training data captured at fit time; `None` while untrained.
    x_train: Option<Array2<Float>>,
    /// Top eigenvectors of the training Gram matrix; `None` while untrained.
    eigenvectors: Option<Array2<Float>>,

    /// Zero-sized marker carrying the typestate.
    _state: PhantomData<State>,
}
706
707impl InfiniteWidthKernel<Untrained> {
708 pub fn new(n_layers: usize, activation: Activation) -> Self {
710 Self {
711 n_layers,
712 activation,
713 n_components: 100,
714 x_train: None,
715 eigenvectors: None,
716 _state: PhantomData,
717 }
718 }
719
720 pub fn n_components(mut self, n: usize) -> Self {
722 self.n_components = n;
723 self
724 }
725
726 fn compute_nngp_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
728 let n_x = x.nrows();
729 let n_y = y.nrows();
730 let d = x.ncols() as Float;
731
732 let mut kernel = x.dot(&y.t());
734 kernel.mapv_inplace(|k| k / d);
735
736 for _ in 0..self.n_layers {
738 let mut new_kernel = Array2::zeros((n_x, n_y));
739
740 for i in 0..n_x {
741 for j in 0..n_y {
742 let k_ij = kernel[[i, j]];
743 let k_ii = if i < n_y { kernel[[i, i]] } else { 1.0 };
744 let k_jj = if j < n_x { kernel[[j, j]] } else { 1.0 };
745
746 let norm = (k_ii * k_jj).sqrt().max(1e-10);
747 let rho = (k_ij / norm).max(-1.0).min(1.0);
748
749 new_kernel[[i, j]] = norm * self.activation.kernel_value(rho);
750 }
751 }
752
753 kernel = new_kernel;
754 }
755
756 kernel
757 }
758
759 fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
760 let n = kernel.nrows();
761 let mut eigenvectors = Array2::zeros((n, k));
762 let mut kernel_deflated = kernel.clone();
763
764 let mut rng = thread_rng();
765 let normal_dist = Normal::new(0.0, 1.0).unwrap();
766
767 for i in 0..k {
768 let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal_dist));
769
770 for _iter in 0..50 {
772 v = kernel_deflated.dot(&v);
773 let norm = v.dot(&v).sqrt();
774 if norm > 1e-10 {
775 v /= norm;
776 } else {
777 break;
778 }
779 }
780
781 for j in 0..n {
782 eigenvectors[[j, i]] = v[j];
783 }
784
785 let lambda = v.dot(&kernel_deflated.dot(&v));
786 for row in 0..n {
787 for col in 0..n {
788 kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
789 }
790 }
791 }
792
793 Ok(eigenvectors)
794 }
795}
796
// Estimator plumbing: this transformer has no configuration struct, so
// the unit type stands in. `&()` is a const-promoted `'static` reference.
impl Estimator for InfiniteWidthKernel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
806
807impl Fit<Array2<Float>, ()> for InfiniteWidthKernel<Untrained> {
808 type Fitted = InfiniteWidthKernel<Trained>;
809
810 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
811 if x.nrows() == 0 || x.ncols() == 0 {
812 return Err(SklearsError::InvalidInput(
813 "Input array cannot be empty".to_string(),
814 ));
815 }
816
817 let x_train = x.clone();
818 let kernel = self.compute_nngp_kernel(x, x);
819
820 let n_components = self.n_components.min(x.nrows());
822 let eigenvectors = self.compute_top_eigenvectors(&kernel, n_components)?;
823
824 Ok(InfiniteWidthKernel {
825 n_layers: self.n_layers,
826 activation: self.activation,
827 n_components: self.n_components,
828 x_train: Some(x_train),
829 eigenvectors: Some(eigenvectors),
830 _state: PhantomData,
831 })
832 }
833}
834
835impl Transform<Array2<Float>, Array2<Float>> for InfiniteWidthKernel<Trained> {
836 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
837 let x_train = self.x_train.as_ref().unwrap();
838 let eigenvectors = self.eigenvectors.as_ref().unwrap();
839
840 if x.ncols() != x_train.ncols() {
841 return Err(SklearsError::InvalidInput(format!(
842 "Feature dimension mismatch: expected {}, got {}",
843 x_train.ncols(),
844 x.ncols()
845 )));
846 }
847
848 let kernel_obj = InfiniteWidthKernel::<Untrained> {
849 n_layers: self.n_layers,
850 activation: self.activation,
851 n_components: self.n_components,
852 x_train: None,
853 eigenvectors: None,
854 _state: PhantomData,
855 };
856
857 let kernel = kernel_obj.compute_nngp_kernel(x, x_train);
858 Ok(kernel.dot(eigenvectors))
859 }
860}
861
impl InfiniteWidthKernel<Trained> {
    /// Training data captured at fit time.
    ///
    /// # Panics
    /// Never in practice: the `Trained` typestate guarantees both fields
    /// below were populated by `fit`.
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().unwrap()
    }

    /// Top eigenvectors of the training Gram matrix.
    pub fn eigenvectors(&self) -> &Array2<Float> {
        self.eigenvectors.as_ref().unwrap()
    }
}
873
// Smoke tests: each exercises the public fit/transform surface on tiny
// fixtures and checks output shapes and finiteness rather than exact
// values (kernels here depend on random initialization).
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Every activation must yield finite values in both pointwise and
    // kernel form at a representative input.
    #[test]
    fn test_activation_functions() {
        let activations = vec![
            Activation::ReLU,
            Activation::Tanh,
            Activation::Sigmoid,
            Activation::Linear,
            Activation::GELU,
            Activation::Swish,
            Activation::Erf,
        ];

        for act in activations {
            let val = act.apply(0.5);
            assert!(val.is_finite());

            let kernel_val = act.kernel_value(0.5);
            assert!(kernel_val.is_finite());
        }
    }

    // End-to-end NTK fit/transform preserves the sample count.
    #[test]
    fn test_neural_tangent_kernel_basic() {
        let config = NTKConfig {
            n_layers: 2,
            hidden_width: Some(512),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 0.1,
        };

        let ntk = NeuralTangentKernel::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = ntk.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
    }

    // DKL output width must equal the configured component count.
    #[test]
    fn test_deep_kernel_learning() {
        let config = DKLConfig {
            feature_layers: vec![10, 20],
            n_components: 50,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);
    }

    // NNGP fit/transform preserves the sample count.
    #[test]
    fn test_infinite_width_kernel() {
        let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.nrows(), 4);
        assert!(features.ncols() > 0);
    }

    // The NTK pipeline must work for several activation choices.
    #[test]
    fn test_ntk_different_activations() {
        let activations = vec![Activation::ReLU, Activation::Tanh, Activation::GELU];
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        for act in activations {
            let ntk = NeuralTangentKernel::with_layers(2).activation(act);
            let fitted = ntk.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();

            assert_eq!(features.nrows(), 2);
        }
    }

    // DKL embeddings must have the configured shape and be finite.
    #[test]
    fn test_dkl_feature_extraction() {
        let config = DKLConfig {
            feature_layers: vec![8, 4],
            n_components: 20,
            activation: Activation::Tanh,
            gamma: 0.5,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).unwrap();

        let features = fitted.transform(&x).unwrap();
        assert_eq!(features.shape(), &[2, 20]);

        for val in features.iter() {
            assert!(val.is_finite());
        }
    }

    // Fitting on an empty array is rejected with an error.
    #[test]
    fn test_empty_input_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_empty: Array2<Float> = Array2::zeros((0, 0));

        assert!(ntk.fit(&x_empty, &()).is_err());
    }

    // Transforming data with a different feature count is rejected.
    #[test]
    fn test_dimension_mismatch_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];

        let fitted = ntk.fit(&x_train, &()).unwrap();
        assert!(fitted.transform(&x_test).is_err());
    }
}