//! Neural-network-inspired kernel approximations: the Neural Tangent Kernel (NTK),
//! Deep Kernel Learning (DKL) with random Fourier features, and an NNGP-style
//! infinite-width kernel.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::essentials::{Normal, Uniform};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

const PI: Float = std::f64::consts::PI;

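/// Activation functions supported by the neural-network kernels in this module.
///
/// Each variant provides both a pointwise nonlinearity ([`Activation::apply`]) and a
/// corresponding (or approximate) dual kernel used in the layer-wise infinite-width
/// recursions ([`Activation::kernel_value`]).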
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum Activation {
    ReLU,
    Tanh,
    Sigmoid,
    Erf,
    Linear,
    GELU,
    Swish,
}

impl Activation {
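    /// Apply the activation function pointwise to a single value.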
    pub fn apply(&self, x: Float) -> Float {
        match self {
            Activation::ReLU => x.max(0.0),
            Activation::Tanh => x.tanh(),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Erf => {
                // Abramowitz & Stegun 7.1.26 polynomial approximation of erf(x).
                let sign = if x >= 0.0 { 1.0 } else { -1.0 };
                let x_abs = x.abs();
                let t = 1.0 / (1.0 + 0.3275911 * x_abs);
                let approx = 1.0
                    - (((((1.061405429 * t - 1.453152027) * t) + 1.421413741) * t - 0.284496736)
                        * t
                        + 0.254829592)
                        * t
                        * (-x_abs * x_abs).exp();
                sign * approx
            }
            Activation::Linear => x,
            Activation::GELU => {
                // tanh approximation of GELU.
                let sqrt_2_over_pi = (2.0 / PI).sqrt();
                0.5 * x * (1.0 + (sqrt_2_over_pi * (x + 0.044715 * x.powi(3))).tanh())
            }
            Activation::Swish => {
                let sigmoid = 1.0 / (1.0 + (-x).exp());
                x * sigmoid
            }
        }
    }

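    /// Dual kernel of this activation used in the layer-wise infinite-width recursions,
    /// evaluated at the correlation `rho` between two pre-activations.
    ///
    /// ReLU uses the first-order arc-cosine kernel; Tanh, Erf and Sigmoid use
    /// arcsine-style forms; GELU and Swish are heuristic rescalings of the ReLU
    /// arc-cosine kernel.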
    pub fn kernel_value(&self, rho: Float) -> Float {
        match self {
            Activation::ReLU => {
                // First-order arc-cosine kernel.
                let theta = rho.max(-1.0).min(1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI)
            }
            Activation::Tanh => {
                // Arcsine-style approximation; clamp the argument so that asin never
                // receives a value outside [-1, 1] (which would yield NaN near rho = 1).
                2.0 / PI * (rho * (1.0 + rho.powi(2)).sqrt()).clamp(-1.0, 1.0).asin()
            }
            Activation::Erf => {
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).sqrt())).asin()
            }
            Activation::Linear => rho,
            Activation::Sigmoid => {
                2.0 / PI * (rho / (1.0 + (1.0 - rho.powi(2)).abs().sqrt())).asin()
            }
            Activation::GELU => {
                // Heuristic: ReLU arc-cosine kernel scaled by the GELU/sigmoid constant 1.702.
                let theta = rho.max(-1.0).min(1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.702
            }
            Activation::Swish => {
                // Heuristic: ReLU arc-cosine kernel with a fixed 1.5 scaling.
                let theta = rho.max(-1.0).min(1.0).acos();
                (theta.sin() + (PI - theta) * theta.cos()) / (2.0 * PI) * 1.5
            }
        }
    }
}

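/// Configuration for the [`NeuralTangentKernel`] transformer: network depth, optional
/// hidden width, activation, and the weight/bias variances used in the layer-wise
/// kernel recursion.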
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NTKConfig {
    pub n_layers: usize,
    pub hidden_width: Option<usize>,
    pub activation: Activation,
    pub infinite_width: bool,
    pub weight_variance: Float,
    pub bias_variance: Float,
}

impl Default for NTKConfig {
    fn default() -> Self {
        Self {
            n_layers: 3,
            hidden_width: Some(1024),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 1.0,
        }
    }
}

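/// Neural Tangent Kernel (NTK) feature transformer.
///
/// `fit` stores the training inputs, builds the train/train NTK Gram matrix and,
/// when fewer components than samples are requested, its leading eigenvectors;
/// `transform` evaluates the kernel between new inputs and the stored training set
/// and projects onto those eigenvectors when they are available.
///
/// # Example
///
/// A minimal sketch mirroring the unit tests below (assumes the `Fit` and
/// `Transform` traits are in scope):
///
/// ```ignore
/// let ntk = NeuralTangentKernel::with_layers(2).activation(Activation::ReLU);
/// let fitted = ntk.fit(&x, &())?;
/// let features = fitted.transform(&x)?;
/// ```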
#[derive(Debug, Clone)]
pub struct NeuralTangentKernel<State = Untrained> {
    config: NTKConfig,
    n_components: usize,

    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl NeuralTangentKernel<Untrained> {
    pub fn new(config: NTKConfig) -> Self {
        Self {
            config,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    pub fn with_layers(n_layers: usize) -> Self {
        Self {
            config: NTKConfig {
                n_layers,
                ..Default::default()
            },
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    pub fn infinite_width(mut self, infinite: bool) -> Self {
        self.config.infinite_width = infinite;
        self
    }

    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

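    /// Compute the (approximate) NTK Gram matrix between the rows of `x` and `y`.
    ///
    /// Starts from the scaled linear kernel x·yᵀ / d and applies `n_layers` rounds of
    /// the layer-wise recursion: normalise each entry to a correlation, pass it through
    /// the activation's dual kernel, then rescale by `weight_variance` and add
    /// `bias_variance`. When `x != y`, the diagonal entries used for normalisation are
    /// taken from the cross matrix (falling back to 1.0), so that step is approximate.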
    fn compute_ntk_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Result<Array2<Float>> {
        let n_samples_x = x.nrows();
        let n_samples_y = y.nrows();

        let mut kernel = x.dot(&y.t());

        let d = x.ncols() as Float;
        kernel.mapv_inplace(|k| k / d);

        for _layer in 0..self.config.n_layers {
            let mut new_kernel = Array2::zeros((n_samples_x, n_samples_y));

            for i in 0..n_samples_x {
                for j in 0..n_samples_y {
                    let k_ij = kernel[[i, j]];
                    let k_ii = if i < kernel.nrows() && i < kernel.ncols() {
                        kernel[[i, i]]
                    } else {
                        1.0
                    };
                    let k_jj = if j < kernel.nrows() && j < kernel.ncols() {
                        kernel[[j, j]]
                    } else {
                        1.0
                    };

                    let norm = (k_ii * k_jj).sqrt().max(1e-10);
                    let rho = (k_ij / norm).max(-1.0).min(1.0);

                    let activated = self.config.activation.kernel_value(rho);

                    new_kernel[[i, j]] =
                        self.config.weight_variance * norm * activated + self.config.bias_variance;
                }
            }

            kernel = new_kernel;
        }

        Ok(kernel)
    }

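    /// Approximate the top-`k` eigenvectors of `kernel` by power iteration with
    /// deflation: each eigenvector is refined for a fixed number of iterations, then
    /// its rank-one component is subtracted from the deflated kernel before the next
    /// one is computed.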
    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");

        for i in 0..k {
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal));

            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}

impl Estimator for NeuralTangentKernel<Untrained> {
    type Config = NTKConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

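// Fitting stores the training inputs, computes the train/train NTK matrix, and,
// when fewer components than samples are requested, its leading eigenvectors.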
impl Fit<Array2<Float>, ()> for NeuralTangentKernel<Untrained> {
    type Fitted = NeuralTangentKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let x_train = x.clone();

        let kernel = self.compute_ntk_kernel(x, x)?;

        let n_components = x.nrows().min(self.n_components);

        let eigenvectors = if n_components < x.nrows() {
            Some(self.compute_top_eigenvectors(&kernel, n_components)?)
        } else {
            None
        };

        Ok(NeuralTangentKernel {
            config: self.config,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors,
            _state: PhantomData,
        })
    }
}

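// Transformation evaluates the NTK between the new inputs and the stored training
// set, and projects onto the fitted eigenvectors when they were computed.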
impl Transform<Array2<Float>, Array2<Float>> for NeuralTangentKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self.x_train.as_ref().expect("operation should succeed");

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        let ntk = NeuralTangentKernel::<Untrained> {
            config: self.config.clone(),
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };
        let kernel = ntk.compute_ntk_kernel(x, x_train)?;

        if let Some(ref eigvecs) = self.eigenvectors {
            Ok(kernel.dot(eigvecs))
        } else {
            Ok(kernel)
        }
    }
}

impl NeuralTangentKernel<Trained> {
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().expect("operation should succeed")
    }

    pub fn eigenvectors(&self) -> Option<&Array2<Float>> {
        self.eigenvectors.as_ref()
    }
}

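/// Configuration for [`DeepKernelLearning`]: widths of the feature-extractor layers,
/// the number of random Fourier features, the activation, the RBF kernel width
/// `gamma`, and a learning rate (not used by the random-feature fit in this module).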
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DKLConfig {
    pub feature_layers: Vec<usize>,
    pub n_components: usize,
    pub activation: Activation,
    pub gamma: Float,
    pub learning_rate: Float,
}

impl Default for DKLConfig {
    fn default() -> Self {
        Self {
            feature_layers: vec![64, 32],
            n_components: 100,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        }
    }
}

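/// Deep Kernel Learning transformer: a small, randomly initialised feed-forward
/// feature extractor followed by a random Fourier feature (RFF) approximation of an
/// RBF kernel on the extracted features.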
#[derive(Debug, Clone)]
pub struct DeepKernelLearning<State = Untrained> {
    config: DKLConfig,

    layer_weights: Option<Vec<Array2<Float>>>,
    layer_biases: Option<Vec<Array1<Float>>>,
    random_weights: Option<Array2<Float>>,
    random_offset: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl DeepKernelLearning<Untrained> {
    pub fn new(config: DKLConfig) -> Self {
        Self {
            config,
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    pub fn with_components(n_components: usize) -> Self {
        Self {
            config: DKLConfig {
                n_components,
                ..Default::default()
            },
            layer_weights: None,
            layer_biases: None,
            random_weights: None,
            random_offset: None,
            _state: PhantomData,
        }
    }

    pub fn activation(mut self, activation: Activation) -> Self {
        self.config.activation = activation;
        self
    }

    pub fn gamma(mut self, gamma: Float) -> Self {
        self.config.gamma = gamma;
        self
    }
}

impl Estimator for DeepKernelLearning<Untrained> {
    type Config = DKLConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

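// Fitting draws random feed-forward weights with Glorot-style scaling, plus the
// random Fourier feature projection: weights with standard deviation sqrt(2 * gamma)
// and phase offsets drawn uniformly from [0, 2*pi).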
impl Fit<Array2<Float>, ()> for DeepKernelLearning<Untrained> {
    type Fitted = DeepKernelLearning<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).expect("operation should succeed");

        let mut layer_weights = Vec::new();
        let mut layer_biases = Vec::new();

        let mut in_features = x.ncols();
        for &out_features in &self.config.feature_layers {
            let scale = (2.0 / (in_features + out_features) as Float).sqrt();

            let weights = Array2::from_shape_fn((in_features, out_features), |_| {
                rng.sample(normal_dist) * scale
            });

            let biases = Array1::from_shape_fn(out_features, |_| rng.sample(normal_dist) * 0.01);

            layer_weights.push(weights);
            layer_biases.push(biases);
            in_features = out_features;
        }

        let final_features = if self.config.feature_layers.is_empty() {
            x.ncols()
        } else {
            *self
                .config
                .feature_layers
                .last()
                .expect("operation should succeed")
        };

        let random_weights =
            Array2::from_shape_fn((final_features, self.config.n_components), |_| {
                rng.sample(normal_dist) * (2.0 * self.config.gamma).sqrt()
            });

        let uniform_dist = Uniform::new(0.0, 2.0 * PI).expect("operation should succeed");
        let random_offset =
            Array1::from_shape_fn(self.config.n_components, |_| rng.sample(uniform_dist));

        Ok(DeepKernelLearning {
            config: self.config,
            layer_weights: Some(layer_weights),
            layer_biases: Some(layer_biases),
            random_weights: Some(random_weights),
            random_offset: Some(random_offset),
            _state: PhantomData,
        })
    }
}

impl DeepKernelLearning<Trained> {
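    /// Pass `x` through the randomly initialised feed-forward layers: affine map,
    /// bias, then the configured activation, repeated for every layer.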
    fn extract_features(&self, x: &Array2<Float>) -> Array2<Float> {
        let mut features = x.clone();
        let layer_weights = self
            .layer_weights
            .as_ref()
            .expect("operation should succeed");
        let layer_biases = self
            .layer_biases
            .as_ref()
            .expect("operation should succeed");

        for (weights, biases) in layer_weights.iter().zip(layer_biases.iter()) {
            features = features.dot(weights);

            for i in 0..features.nrows() {
                for j in 0..features.ncols() {
                    features[[i, j]] += biases[j];
                }
            }

            features.mapv_inplace(|v| self.config.activation.apply(v));
        }

        features
    }
}

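// Transformation applies the deep feature extractor and then the random Fourier
// feature map z(x) = sqrt(2 / D) * cos(W^T phi(x) + b), which approximates an RBF
// kernel on the extracted features.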
impl Transform<Array2<Float>, Array2<Float>> for DeepKernelLearning<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let deep_features = self.extract_features(x);

        let random_weights = self
            .random_weights
            .as_ref()
            .expect("operation should succeed");
        let random_offset = self
            .random_offset
            .as_ref()
            .expect("operation should succeed");

        let projection = deep_features.dot(random_weights);

        let n_samples = x.nrows();
        let mut output = Array2::zeros((n_samples, self.config.n_components));

        let normalizer = (2.0 / self.config.n_components as Float).sqrt();
        for i in 0..n_samples {
            for j in 0..self.config.n_components {
                output[[i, j]] = normalizer * (projection[[i, j]] + random_offset[j]).cos();
            }
        }

        Ok(output)
    }
}

impl DeepKernelLearning<Trained> {
    pub fn layer_weights(&self) -> &Vec<Array2<Float>> {
        self.layer_weights
            .as_ref()
            .expect("operation should succeed")
    }

    pub fn layer_biases(&self) -> &Vec<Array1<Float>> {
        self.layer_biases
            .as_ref()
            .expect("operation should succeed")
    }

    pub fn random_weights(&self) -> &Array2<Float> {
        self.random_weights
            .as_ref()
            .expect("operation should succeed")
    }

    pub fn random_offset(&self) -> &Array1<Float> {
        self.random_offset
            .as_ref()
            .expect("operation should succeed")
    }
}

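/// NNGP-style infinite-width kernel transformer: the layer-wise kernel of an
/// infinitely wide network with the given depth and activation, projected onto the
/// top eigenvectors of the training Gram matrix.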
#[derive(Debug, Clone)]
pub struct InfiniteWidthKernel<State = Untrained> {
    n_layers: usize,
    activation: Activation,
    n_components: usize,

    x_train: Option<Array2<Float>>,
    eigenvectors: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl InfiniteWidthKernel<Untrained> {
    pub fn new(n_layers: usize, activation: Activation) -> Self {
        Self {
            n_layers,
            activation,
            n_components: 100,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        }
    }

    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

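    /// Compute the NNGP kernel between the rows of `x` and `y`: start from the scaled
    /// linear kernel x·yᵀ / d and, for each layer, renormalise each entry to a
    /// correlation and apply the activation's dual kernel.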
    fn compute_nngp_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
        let n_x = x.nrows();
        let n_y = y.nrows();
        let d = x.ncols() as Float;

        let mut kernel = x.dot(&y.t());
        kernel.mapv_inplace(|k| k / d);

        for _ in 0..self.n_layers {
            let mut new_kernel = Array2::zeros((n_x, n_y));

            for i in 0..n_x {
                for j in 0..n_y {
                    let k_ij = kernel[[i, j]];
                    // `kernel` is n_x x n_y, so [[i, i]] exists only when i is also a valid
                    // column index (i < n_y), and [[j, j]] only when j is a valid row index
                    // (j < n_x); otherwise fall back to unit variance.
                    let k_ii = if i < n_y { kernel[[i, i]] } else { 1.0 };
                    let k_jj = if j < n_x { kernel[[j, j]] } else { 1.0 };

                    let norm = (k_ii * k_jj).sqrt().max(1e-10);
                    let rho = (k_ij / norm).max(-1.0).min(1.0);

                    new_kernel[[i, j]] = norm * self.activation.kernel_value(rho);
                }
            }

            kernel = new_kernel;
        }

        kernel
    }

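    /// Approximate the top-`k` eigenvectors of `kernel` by power iteration with
    /// deflation (same scheme as `NeuralTangentKernel::compute_top_eigenvectors`).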
    fn compute_top_eigenvectors(&self, kernel: &Array2<Float>, k: usize) -> Result<Array2<Float>> {
        let n = kernel.nrows();
        let mut eigenvectors = Array2::zeros((n, k));
        let mut kernel_deflated = kernel.clone();

        let mut rng = thread_rng();
        let normal_dist = Normal::new(0.0, 1.0).expect("operation should succeed");

        for i in 0..k {
            let mut v = Array1::from_shape_fn(n, |_| rng.sample(normal_dist));

            for _iter in 0..50 {
                v = kernel_deflated.dot(&v);
                let norm = v.dot(&v).sqrt();
                if norm > 1e-10 {
                    v /= norm;
                } else {
                    break;
                }
            }

            for j in 0..n {
                eigenvectors[[j, i]] = v[j];
            }

            let lambda = v.dot(&kernel_deflated.dot(&v));
            for row in 0..n {
                for col in 0..n {
                    kernel_deflated[[row, col]] -= lambda * v[row] * v[col];
                }
            }
        }

        Ok(eigenvectors)
    }
}

impl Estimator for InfiniteWidthKernel<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for InfiniteWidthKernel<Untrained> {
    type Fitted = InfiniteWidthKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput(
                "Input array cannot be empty".to_string(),
            ));
        }

        let x_train = x.clone();
        let kernel = self.compute_nngp_kernel(x, x);

        let n_components = self.n_components.min(x.nrows());
        let eigenvectors = self.compute_top_eigenvectors(&kernel, n_components)?;

        Ok(InfiniteWidthKernel {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: Some(x_train),
            eigenvectors: Some(eigenvectors),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for InfiniteWidthKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let x_train = self.x_train.as_ref().expect("operation should succeed");
        let eigenvectors = self
            .eigenvectors
            .as_ref()
            .expect("operation should succeed");

        if x.ncols() != x_train.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Feature dimension mismatch: expected {}, got {}",
                x_train.ncols(),
                x.ncols()
            )));
        }

        let kernel_obj = InfiniteWidthKernel::<Untrained> {
            n_layers: self.n_layers,
            activation: self.activation,
            n_components: self.n_components,
            x_train: None,
            eigenvectors: None,
            _state: PhantomData,
        };

        let kernel = kernel_obj.compute_nngp_kernel(x, x_train);
        Ok(kernel.dot(eigenvectors))
    }
}

impl InfiniteWidthKernel<Trained> {
    pub fn x_train(&self) -> &Array2<Float> {
        self.x_train.as_ref().expect("operation should succeed")
    }

    pub fn eigenvectors(&self) -> &Array2<Float> {
        self.eigenvectors
            .as_ref()
            .expect("operation should succeed")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_activation_functions() {
        let activations = vec![
            Activation::ReLU,
            Activation::Tanh,
            Activation::Sigmoid,
            Activation::Linear,
            Activation::GELU,
            Activation::Swish,
            Activation::Erf,
        ];

        for act in activations {
            let val = act.apply(0.5);
            assert!(val.is_finite());

            let kernel_val = act.kernel_value(0.5);
            assert!(kernel_val.is_finite());
        }
    }

    #[test]
    fn test_neural_tangent_kernel_basic() {
        let config = NTKConfig {
            n_layers: 2,
            hidden_width: Some(512),
            activation: Activation::ReLU,
            infinite_width: true,
            weight_variance: 1.0,
            bias_variance: 0.1,
        };

        let ntk = NeuralTangentKernel::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = ntk.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_deep_kernel_learning() {
        let config = DKLConfig {
            feature_layers: vec![10, 20],
            n_components: 50,
            activation: Activation::ReLU,
            gamma: 1.0,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[3, 50]);
    }

    #[test]
    fn test_infinite_width_kernel() {
        let kernel = InfiniteWidthKernel::new(3, Activation::ReLU);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.nrows(), 4);
        assert!(features.ncols() > 0);
    }

    #[test]
    fn test_ntk_different_activations() {
        let activations = vec![Activation::ReLU, Activation::Tanh, Activation::GELU];
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        for act in activations {
            let ntk = NeuralTangentKernel::with_layers(2).activation(act);
            let fitted = ntk.fit(&x, &()).expect("operation should succeed");
            let features = fitted.transform(&x).expect("operation should succeed");

            assert_eq!(features.nrows(), 2);
        }
    }

    #[test]
    fn test_dkl_feature_extraction() {
        let config = DKLConfig {
            feature_layers: vec![8, 4],
            n_components: 20,
            activation: Activation::Tanh,
            gamma: 0.5,
            learning_rate: 0.01,
        };

        let dkl = DeepKernelLearning::new(config);
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let fitted = dkl.fit(&x, &()).expect("operation should succeed");

        let features = fitted.transform(&x).expect("operation should succeed");
        assert_eq!(features.shape(), &[2, 20]);

        for val in features.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_empty_input_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_empty: Array2<Float> = Array2::zeros((0, 0));

        assert!(ntk.fit(&x_empty, &()).is_err());
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let ntk = NeuralTangentKernel::with_layers(2);
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let x_test = array![[1.0, 2.0, 3.0]];

        let fitted = ntk.fit(&x_train, &()).expect("operation should succeed");
        assert!(fitted.transform(&x_test).is_err());
    }
}