1use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::{thread_rng, Distribution};
16use sklears_core::{
17 error::{Result, SklearsError},
18 prelude::{Fit, Transform},
19 traits::{Trained, Untrained},
20 types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25pub struct GenomicKernel<State = Untrained> {
37 k: usize,
39 n_components: usize,
41 normalize: bool,
43 projection: Option<Array2<Float>>,
45 kmer_vocab: Option<HashMap<String, usize>>,
47 _state: PhantomData<State>,
49}
50
51impl GenomicKernel<Untrained> {
52 pub fn new(k: usize, n_components: usize) -> Self {
54 Self {
55 k,
56 n_components,
57 normalize: true,
58 projection: None,
59 kmer_vocab: None,
60 _state: PhantomData,
61 }
62 }
63
64 pub fn normalize(mut self, normalize: bool) -> Self {
66 self.normalize = normalize;
67 self
68 }
69}
70
71impl Default for GenomicKernel<Untrained> {
72 fn default() -> Self {
73 Self::new(3, 100)
74 }
75}
76
77impl Fit<Array2<Float>, ()> for GenomicKernel<Untrained> {
78 type Fitted = GenomicKernel<Trained>;
79
80 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
81 let n_samples = x.nrows();
82 if n_samples == 0 {
83 return Err(SklearsError::InvalidInput(
84 "Input array cannot be empty".to_string(),
85 ));
86 }
87
88 let vocab_size = 4usize.pow(self.k as u32);
91 let mut kmer_vocab = HashMap::new();
92
93 for i in 0..vocab_size {
95 let kmer = format!("kmer_{}", i);
96 kmer_vocab.insert(kmer, i);
97 }
98
99 let mut rng = thread_rng();
101 let normal = Normal::new(0.0, 1.0 / (vocab_size as Float).sqrt()).unwrap();
102
103 let mut projection = Array2::zeros((vocab_size, self.n_components));
104 for i in 0..vocab_size {
105 for j in 0..self.n_components {
106 projection[[i, j]] = normal.sample(&mut rng);
107 }
108 }
109
110 Ok(GenomicKernel {
111 k: self.k,
112 n_components: self.n_components,
113 normalize: self.normalize,
114 projection: Some(projection),
115 kmer_vocab: Some(kmer_vocab),
116 _state: PhantomData,
117 })
118 }
119}
120
121impl Transform<Array2<Float>, Array2<Float>> for GenomicKernel<Trained> {
122 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
123 let n_samples = x.nrows();
124 let n_features = x.ncols();
125
126 if n_samples == 0 {
127 return Err(SklearsError::InvalidInput(
128 "Input array cannot be empty".to_string(),
129 ));
130 }
131
132 let projection = self.projection.as_ref().unwrap();
133 let vocab_size = projection.nrows();
134
135 let mut kmer_counts = Array2::zeros((n_samples, vocab_size));
137
138 for i in 0..n_samples {
139 for j in 0..n_features.min(vocab_size) {
140 kmer_counts[[i, j]] = x[[i, j % n_features]].abs();
142 }
143
144 if self.normalize {
146 let row_sum: Float = kmer_counts.row(i).sum();
147 if row_sum > 0.0 {
148 for j in 0..vocab_size {
149 kmer_counts[[i, j]] /= row_sum;
150 }
151 }
152 }
153 }
154
155 let features = kmer_counts.dot(projection);
157
158 Ok(features)
159 }
160}
161
162pub struct ProteinKernel<State = Untrained> {
174 pattern_length: usize,
176 n_components: usize,
178 use_properties: bool,
180 projection: Option<Array2<Float>>,
182 property_weights: Option<Array1<Float>>,
184 _state: PhantomData<State>,
186}
187
188impl ProteinKernel<Untrained> {
189 pub fn new(pattern_length: usize, n_components: usize) -> Self {
191 Self {
192 pattern_length,
193 n_components,
194 use_properties: true,
195 projection: None,
196 property_weights: None,
197 _state: PhantomData,
198 }
199 }
200
201 pub fn use_properties(mut self, use_properties: bool) -> Self {
203 self.use_properties = use_properties;
204 self
205 }
206}
207
208impl Default for ProteinKernel<Untrained> {
209 fn default() -> Self {
210 Self::new(3, 100)
211 }
212}
213
214impl Fit<Array2<Float>, ()> for ProteinKernel<Untrained> {
215 type Fitted = ProteinKernel<Trained>;
216
217 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
218 let n_samples = x.nrows();
219 let _n_features = x.ncols();
220
221 if n_samples == 0 {
222 return Err(SklearsError::InvalidInput(
223 "Input array cannot be empty".to_string(),
224 ));
225 }
226
227 let feature_dim = if self.use_properties { 20 + 5 } else { 20 };
229
230 let mut rng = thread_rng();
231 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
232
233 let mut projection = Array2::zeros((feature_dim * self.pattern_length, self.n_components));
235 for i in 0..(feature_dim * self.pattern_length) {
236 for j in 0..self.n_components {
237 projection[[i, j]] = normal.sample(&mut rng);
238 }
239 }
240
241 let property_weights = if self.use_properties {
243 Some(Array1::from_vec(vec![1.0, 0.8, 0.6, 0.7, 0.5]))
244 } else {
245 None
246 };
247
248 Ok(ProteinKernel {
249 pattern_length: self.pattern_length,
250 n_components: self.n_components,
251 use_properties: self.use_properties,
252 projection: Some(projection),
253 property_weights,
254 _state: PhantomData,
255 })
256 }
257}
258
259impl Transform<Array2<Float>, Array2<Float>> for ProteinKernel<Trained> {
260 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
261 let n_samples = x.nrows();
262 let n_features = x.ncols();
263
264 if n_samples == 0 {
265 return Err(SklearsError::InvalidInput(
266 "Input array cannot be empty".to_string(),
267 ));
268 }
269
270 let projection = self.projection.as_ref().unwrap();
271 let feature_dim = projection.nrows();
272
273 let mut protein_features = Array2::zeros((n_samples, feature_dim));
275
276 for i in 0..n_samples {
277 for j in 0..n_features.min(feature_dim) {
278 let aa_value = x[[i, j % n_features]].abs();
280 protein_features[[i, j]] = aa_value;
281
282 if self.use_properties && j + 20 < feature_dim {
284 if let Some(weights) = &self.property_weights {
285 for (prop_idx, &weight) in weights.iter().enumerate() {
286 if j + 20 + prop_idx < feature_dim {
287 protein_features[[i, j + 20 + prop_idx]] = aa_value * weight;
288 }
289 }
290 }
291 }
292 }
293 }
294
295 let features = protein_features.dot(projection);
297
298 Ok(features)
299 }
300}
301
302pub struct PhylogeneticKernel<State = Untrained> {
314 n_components: usize,
316 tree_depth: usize,
318 use_branch_lengths: bool,
320 projection: Option<Array2<Float>>,
322 branch_weights: Option<Array1<Float>>,
324 _state: PhantomData<State>,
326}
327
328impl PhylogeneticKernel<Untrained> {
329 pub fn new(n_components: usize, tree_depth: usize) -> Self {
331 Self {
332 n_components,
333 tree_depth,
334 use_branch_lengths: true,
335 projection: None,
336 branch_weights: None,
337 _state: PhantomData,
338 }
339 }
340
341 pub fn use_branch_lengths(mut self, use_branch_lengths: bool) -> Self {
343 self.use_branch_lengths = use_branch_lengths;
344 self
345 }
346}
347
348impl Default for PhylogeneticKernel<Untrained> {
349 fn default() -> Self {
350 Self::new(100, 5)
351 }
352}
353
354impl Fit<Array2<Float>, ()> for PhylogeneticKernel<Untrained> {
355 type Fitted = PhylogeneticKernel<Trained>;
356
357 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
358 let n_samples = x.nrows();
359 if n_samples == 0 {
360 return Err(SklearsError::InvalidInput(
361 "Input array cannot be empty".to_string(),
362 ));
363 }
364
365 let feature_dim = 2usize.pow(self.tree_depth as u32);
367
368 let mut rng = thread_rng();
369 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
370
371 let mut projection = Array2::zeros((feature_dim, self.n_components));
373 for i in 0..feature_dim {
374 for j in 0..self.n_components {
375 projection[[i, j]] = normal.sample(&mut rng);
376 }
377 }
378
379 let branch_weights = if self.use_branch_lengths {
381 let mut weights = Array1::zeros(self.tree_depth);
382 for i in 0..self.tree_depth {
383 weights[i] = (-(i as Float) * 0.5).exp();
384 }
385 Some(weights)
386 } else {
387 None
388 };
389
390 Ok(PhylogeneticKernel {
391 n_components: self.n_components,
392 tree_depth: self.tree_depth,
393 use_branch_lengths: self.use_branch_lengths,
394 projection: Some(projection),
395 branch_weights,
396 _state: PhantomData,
397 })
398 }
399}
400
401impl Transform<Array2<Float>, Array2<Float>> for PhylogeneticKernel<Trained> {
402 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
403 let n_samples = x.nrows();
404 let n_features = x.ncols();
405
406 if n_samples == 0 {
407 return Err(SklearsError::InvalidInput(
408 "Input array cannot be empty".to_string(),
409 ));
410 }
411
412 let projection = self.projection.as_ref().unwrap();
413 let feature_dim = projection.nrows();
414
415 let mut tree_features = Array2::zeros((n_samples, feature_dim));
417
418 for i in 0..n_samples {
419 for j in 0..n_features.min(feature_dim) {
420 let base_value = x[[i, j % n_features]].abs();
421
422 if self.use_branch_lengths {
424 if let Some(weights) = &self.branch_weights {
425 let depth_idx = j % self.tree_depth;
426 tree_features[[i, j]] = base_value * weights[depth_idx];
427 } else {
428 tree_features[[i, j]] = base_value;
429 }
430 } else {
431 tree_features[[i, j]] = base_value;
432 }
433 }
434 }
435
436 let features = tree_features.dot(projection);
438
439 Ok(features)
440 }
441}
442
443pub struct MetabolicNetworkKernel<State = Untrained> {
455 n_components: usize,
457 max_path_length: usize,
459 use_pathway_enrichment: bool,
461 projection: Option<Array2<Float>>,
463 pathway_weights: Option<Array1<Float>>,
465 _state: PhantomData<State>,
467}
468
469impl MetabolicNetworkKernel<Untrained> {
470 pub fn new(n_components: usize, max_path_length: usize) -> Self {
472 Self {
473 n_components,
474 max_path_length,
475 use_pathway_enrichment: true,
476 projection: None,
477 pathway_weights: None,
478 _state: PhantomData,
479 }
480 }
481
482 pub fn use_pathway_enrichment(mut self, use_pathway_enrichment: bool) -> Self {
484 self.use_pathway_enrichment = use_pathway_enrichment;
485 self
486 }
487}
488
489impl Default for MetabolicNetworkKernel<Untrained> {
490 fn default() -> Self {
491 Self::new(100, 4)
492 }
493}
494
495impl Fit<Array2<Float>, ()> for MetabolicNetworkKernel<Untrained> {
496 type Fitted = MetabolicNetworkKernel<Trained>;
497
498 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
499 let n_samples = x.nrows();
500 if n_samples == 0 {
501 return Err(SklearsError::InvalidInput(
502 "Input array cannot be empty".to_string(),
503 ));
504 }
505
506 let base_dim = 50; let pathway_dim = if self.use_pathway_enrichment { 20 } else { 0 };
509 let feature_dim = base_dim + pathway_dim;
510
511 let mut rng = thread_rng();
512 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
513
514 let mut projection = Array2::zeros((feature_dim, self.n_components));
516 for i in 0..feature_dim {
517 for j in 0..self.n_components {
518 projection[[i, j]] = normal.sample(&mut rng);
519 }
520 }
521
522 let pathway_weights = if self.use_pathway_enrichment {
524 let mut weights = Array1::zeros(pathway_dim);
525 for i in 0..pathway_dim {
526 weights[i] = 1.0 / (1.0 + (i as Float) * 0.1);
528 }
529 Some(weights)
530 } else {
531 None
532 };
533
534 Ok(MetabolicNetworkKernel {
535 n_components: self.n_components,
536 max_path_length: self.max_path_length,
537 use_pathway_enrichment: self.use_pathway_enrichment,
538 projection: Some(projection),
539 pathway_weights,
540 _state: PhantomData,
541 })
542 }
543}
544
545impl Transform<Array2<Float>, Array2<Float>> for MetabolicNetworkKernel<Trained> {
546 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
547 let n_samples = x.nrows();
548 let n_features = x.ncols();
549
550 if n_samples == 0 {
551 return Err(SklearsError::InvalidInput(
552 "Input array cannot be empty".to_string(),
553 ));
554 }
555
556 let projection = self.projection.as_ref().unwrap();
557 let feature_dim = projection.nrows();
558
559 let mut network_features = Array2::zeros((n_samples, feature_dim));
561
562 for i in 0..n_samples {
563 for j in 0..n_features.min(feature_dim) {
565 network_features[[i, j]] = x[[i, j % n_features]].abs();
566 }
567
568 if self.use_pathway_enrichment {
570 if let Some(weights) = &self.pathway_weights {
571 let pathway_start = 50;
572 for (pathway_idx, &weight) in weights.iter().enumerate() {
573 if pathway_start + pathway_idx < feature_dim {
574 let pathway_value = x[[i, pathway_idx % n_features]].abs() * weight;
576 network_features[[i, pathway_start + pathway_idx]] = pathway_value;
577 }
578 }
579 }
580 }
581 }
582
583 let features = network_features.dot(projection);
585
586 Ok(features)
587 }
588}
589
/// Strategy used by `MultiOmicsKernel` to combine its per-omics projections.
///
/// The input columns are split into equally sized contiguous blocks, one per
/// omics type, and each block is multiplied by its own random projection.
/// Every variant combines those projections into `n_components` features.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OmicsIntegrationMethod {
    /// Unweighted mean of the per-omics projections. Despite the name, the
    /// projected blocks are averaged, not concatenated.
    Concatenation,
    /// Per-omics projections weighted by the fitted omics weights, summed, and
    /// divided by the sum of those weights.
    WeightedAverage,
    /// Each block's projection plus `0.1` times every other block's projection,
    /// averaged over all omics types.
    CrossCorrelation,
    /// Weighted combination of the per-view projections; currently computed
    /// with the same formula as `WeightedAverage`.
    MultiViewLearning,
}
606
607pub struct MultiOmicsKernel<State = Untrained> {
615 n_components: usize,
617 n_omics_types: usize,
619 integration_method: OmicsIntegrationMethod,
621 projections: Option<Vec<Array2<Float>>>,
623 omics_weights: Option<Array1<Float>>,
625 _state: PhantomData<State>,
627}
628
629impl MultiOmicsKernel<Untrained> {
630 pub fn new(n_components: usize, n_omics_types: usize) -> Self {
632 Self {
633 n_components,
634 n_omics_types,
635 integration_method: OmicsIntegrationMethod::WeightedAverage,
636 projections: None,
637 omics_weights: None,
638 _state: PhantomData,
639 }
640 }
641
642 pub fn integration_method(mut self, method: OmicsIntegrationMethod) -> Self {
644 self.integration_method = method;
645 self
646 }
647}
648
649impl Default for MultiOmicsKernel<Untrained> {
650 fn default() -> Self {
651 Self::new(100, 3)
652 }
653}
654
655impl Fit<Array2<Float>, ()> for MultiOmicsKernel<Untrained> {
656 type Fitted = MultiOmicsKernel<Trained>;
657
658 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
659 let n_samples = x.nrows();
660 let n_features = x.ncols();
661
662 if n_samples == 0 {
663 return Err(SklearsError::InvalidInput(
664 "Input array cannot be empty".to_string(),
665 ));
666 }
667
668 let features_per_omics = n_features / self.n_omics_types;
670
671 let mut rng = thread_rng();
672
673 let mut projections = Vec::new();
675 for _ in 0..self.n_omics_types {
676 let normal = Normal::new(0.0, 1.0 / (features_per_omics as Float).sqrt()).unwrap();
677 let mut projection = Array2::zeros((features_per_omics, self.n_components));
678
679 for i in 0..features_per_omics {
680 for j in 0..self.n_components {
681 projection[[i, j]] = normal.sample(&mut rng);
682 }
683 }
684 projections.push(projection);
685 }
686
687 let mut omics_weights = Array1::zeros(self.n_omics_types);
689 for i in 0..self.n_omics_types {
690 omics_weights[i] = 1.0 / (1.0 + (i as Float) * 0.2);
692 }
693
694 Ok(MultiOmicsKernel {
695 n_components: self.n_components,
696 n_omics_types: self.n_omics_types,
697 integration_method: self.integration_method,
698 projections: Some(projections),
699 omics_weights: Some(omics_weights),
700 _state: PhantomData,
701 })
702 }
703}
704
705impl Transform<Array2<Float>, Array2<Float>> for MultiOmicsKernel<Trained> {
706 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
707 let n_samples = x.nrows();
708 let n_features = x.ncols();
709
710 if n_samples == 0 {
711 return Err(SklearsError::InvalidInput(
712 "Input array cannot be empty".to_string(),
713 ));
714 }
715
716 let projections = self.projections.as_ref().unwrap();
717 let omics_weights = self.omics_weights.as_ref().unwrap();
718 let features_per_omics = n_features / self.n_omics_types;
719
720 let mut result = Array2::zeros((n_samples, self.n_components));
721
722 match self.integration_method {
723 OmicsIntegrationMethod::Concatenation => {
724 for omics_idx in 0..self.n_omics_types {
727 let start_idx = omics_idx * features_per_omics;
728 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
729
730 if start_idx < n_features {
731 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
732 for i in 0..n_samples {
733 for j in 0..(end_idx - start_idx) {
734 omics_data[[i, j]] = x[[i, start_idx + j]];
735 }
736 }
737 let omics_features = omics_data.dot(&projections[omics_idx]);
738 result += &omics_features;
739 }
740 }
741 result /= self.n_omics_types as Float;
742 }
743 OmicsIntegrationMethod::WeightedAverage => {
744 for omics_idx in 0..self.n_omics_types {
746 let start_idx = omics_idx * features_per_omics;
747 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
748
749 if start_idx < n_features {
750 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
751 for i in 0..n_samples {
752 for j in 0..(end_idx - start_idx) {
753 omics_data[[i, j]] = x[[i, start_idx + j]];
754 }
755 }
756 let omics_features = omics_data.dot(&projections[omics_idx]);
757 let weight = omics_weights[omics_idx];
758 result += &(omics_features * weight);
759 }
760 }
761 let weight_sum: Float = omics_weights.sum();
763 result /= weight_sum;
764 }
765 OmicsIntegrationMethod::CrossCorrelation => {
766 for omics_idx in 0..self.n_omics_types {
768 let start_idx = omics_idx * features_per_omics;
769 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
770
771 if start_idx < n_features {
772 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
773 for i in 0..n_samples {
774 for j in 0..(end_idx - start_idx) {
775 omics_data[[i, j]] = x[[i, start_idx + j]];
776 }
777 }
778 let mut omics_features = omics_data.dot(&projections[omics_idx]);
779
780 for other_idx in 0..self.n_omics_types {
782 if other_idx != omics_idx {
783 let other_start = other_idx * features_per_omics;
784 let other_end =
785 ((other_idx + 1) * features_per_omics).min(n_features);
786
787 if other_start < n_features {
788 let mut other_data =
789 Array2::zeros((n_samples, other_end - other_start));
790 for i in 0..n_samples {
791 for j in 0..(other_end - other_start) {
792 other_data[[i, j]] = x[[i, other_start + j]];
793 }
794 }
795 let other_features = other_data.dot(&projections[other_idx]);
796 omics_features += &(&other_features * 0.1);
798 }
799 }
800 }
801
802 result += &omics_features;
803 }
804 }
805 result /= self.n_omics_types as Float;
806 }
807 OmicsIntegrationMethod::MultiViewLearning => {
808 let mut view_features = Vec::new();
810
811 for omics_idx in 0..self.n_omics_types {
812 let start_idx = omics_idx * features_per_omics;
813 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
814
815 if start_idx < n_features {
816 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
817 for i in 0..n_samples {
818 for j in 0..(end_idx - start_idx) {
819 omics_data[[i, j]] = x[[i, start_idx + j]];
820 }
821 }
822 let omics_features = omics_data.dot(&projections[omics_idx]);
823 view_features.push(omics_features);
824 }
825 }
826
827 for (idx, features) in view_features.iter().enumerate() {
829 let weight = omics_weights[idx];
830 result += &(features * weight);
831 }
832 let weight_sum: Float = omics_weights.sum();
833 result /= weight_sum;
834 }
835 }
836
837 Ok(result)
838 }
839}
840
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_genomic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let kernel = GenomicKernel::new(3, 50);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[3, 50]);
    }

    #[test]
    fn test_genomic_kernel_normalization() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = GenomicKernel::new(3, 30).normalize(false);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 30]);
    }

    #[test]
    fn test_genomic_normalization_is_scale_invariant() {
        // With row normalisation, scaling a row by a positive constant must
        // not change its features.
        let x = array![[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]];

        let fitted = GenomicKernel::new(2, 20).fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        for (a, b) in features.row(0).iter().zip(features.row(1).iter()) {
            assert!((a - b).abs() < 1e-9);
        }
    }

    #[test]
    fn test_protein_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = ProteinKernel::new(2, 40);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 40]);
    }

    #[test]
    fn test_protein_kernel_properties() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = ProteinKernel::new(2, 30).use_properties(true);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 30]);
        assert!(fitted.property_weights.is_some());
    }

    #[test]
    fn test_phylogenetic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = PhylogeneticKernel::new(50, 4);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 50]);
    }

    #[test]
    fn test_phylogenetic_kernel_branch_lengths() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = PhylogeneticKernel::new(40, 3).use_branch_lengths(true);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 40]);
        assert!(fitted.branch_weights.is_some());
    }

    #[test]
    fn test_phylogenetic_kernel_without_branch_lengths() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = PhylogeneticKernel::new(20, 3).use_branch_lengths(false);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 20]);
        assert!(fitted.branch_weights.is_none());
    }

    #[test]
    fn test_metabolic_network_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = MetabolicNetworkKernel::new(60, 3);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 60]);
    }

    #[test]
    fn test_metabolic_network_kernel_pathways() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = MetabolicNetworkKernel::new(50, 3).use_pathway_enrichment(true);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 50]);
        assert!(fitted.pathway_weights.is_some());
    }

    #[test]
    fn test_multi_omics_kernel_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        ];

        let kernel = MultiOmicsKernel::new(40, 3);
        let fitted = kernel.fit(&x, &()).unwrap();
        let features = fitted.transform(&x).unwrap();

        assert_eq!(features.shape(), &[2, 40]);
    }

    #[test]
    fn test_multi_omics_integration_methods() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let methods = vec![
            OmicsIntegrationMethod::Concatenation,
            OmicsIntegrationMethod::WeightedAverage,
            OmicsIntegrationMethod::CrossCorrelation,
            OmicsIntegrationMethod::MultiViewLearning,
        ];

        for method in methods {
            let kernel = MultiOmicsKernel::new(30, 2).integration_method(method);
            let fitted = kernel.fit(&x, &()).unwrap();
            let features = fitted.transform(&x).unwrap();
            assert_eq!(features.shape(), &[2, 30]);
        }
    }

    #[test]
    fn test_multi_omics_fitted_state() {
        let x = array![[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]];

        let fitted = MultiOmicsKernel::new(5, 3).fit(&x, &()).unwrap();
        let weights = fitted.omics_weights.as_ref().unwrap();

        assert_eq!(fitted.projections.as_ref().unwrap().len(), 3);
        assert_eq!(weights.len(), 3);
        // Weights 1 / (1 + 0.2 i) strictly decrease with the omics index.
        assert!(weights[0] > weights[1] && weights[1] > weights[2]);
    }

    #[test]
    fn test_default_kernels_output_shapes() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        ];

        let genomic: GenomicKernel = Default::default();
        let genomic = genomic.fit(&x, &()).unwrap();
        assert_eq!(genomic.transform(&x).unwrap().shape(), &[2, 100]);

        let protein: ProteinKernel = Default::default();
        let protein = protein.fit(&x, &()).unwrap();
        assert_eq!(protein.transform(&x).unwrap().shape(), &[2, 100]);

        let phylo: PhylogeneticKernel = Default::default();
        let phylo = phylo.fit(&x, &()).unwrap();
        assert_eq!(phylo.transform(&x).unwrap().shape(), &[2, 100]);

        let metabolic: MetabolicNetworkKernel = Default::default();
        let metabolic = metabolic.fit(&x, &()).unwrap();
        assert_eq!(metabolic.transform(&x).unwrap().shape(), &[2, 100]);

        let omics: MultiOmicsKernel = Default::default();
        let omics = omics.fit(&x, &()).unwrap();
        assert_eq!(omics.transform(&x).unwrap().shape(), &[2, 100]);
    }

    #[test]
    fn test_empty_input_error() {
        let x_empty: Array2<Float> = Array2::zeros((0, 3));

        let kernel = GenomicKernel::new(3, 50);
        assert!(kernel.fit(&x_empty, &()).is_err());

        let kernel2 = ProteinKernel::new(2, 40);
        assert!(kernel2.fit(&x_empty, &()).is_err());

        assert!(PhylogeneticKernel::new(50, 4).fit(&x_empty, &()).is_err());
        assert!(MetabolicNetworkKernel::new(60, 3).fit(&x_empty, &()).is_err());
        assert!(MultiOmicsKernel::new(40, 3).fit(&x_empty, &()).is_err());
    }

    #[test]
    fn test_transform_empty_input_error() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let x_empty: Array2<Float> = Array2::zeros((0, 3));

        let genomic = GenomicKernel::new(2, 10).fit(&x, &()).unwrap();
        assert!(genomic.transform(&x_empty).is_err());

        let protein = ProteinKernel::new(2, 10).fit(&x, &()).unwrap();
        assert!(protein.transform(&x_empty).is_err());

        let phylo = PhylogeneticKernel::new(10, 3).fit(&x, &()).unwrap();
        assert!(phylo.transform(&x_empty).is_err());

        let metabolic = MetabolicNetworkKernel::new(10, 3).fit(&x, &()).unwrap();
        assert!(metabolic.transform(&x_empty).is_err());

        let omics = MultiOmicsKernel::new(10, 3).fit(&x, &()).unwrap();
        assert!(omics.transform(&x_empty).is_err());
    }
}