1use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::{thread_rng, Distribution};
16use sklears_core::{
17 error::{Result, SklearsError},
18 prelude::{Fit, Transform},
19 traits::{Trained, Untrained},
20 types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25pub struct GenomicKernel<State = Untrained> {
37 k: usize,
39 n_components: usize,
41 normalize: bool,
43 projection: Option<Array2<Float>>,
45 kmer_vocab: Option<HashMap<String, usize>>,
47 _state: PhantomData<State>,
49}
50
51impl GenomicKernel<Untrained> {
52 pub fn new(k: usize, n_components: usize) -> Self {
54 Self {
55 k,
56 n_components,
57 normalize: true,
58 projection: None,
59 kmer_vocab: None,
60 _state: PhantomData,
61 }
62 }
63
64 pub fn normalize(mut self, normalize: bool) -> Self {
66 self.normalize = normalize;
67 self
68 }
69}
70
71impl Default for GenomicKernel<Untrained> {
72 fn default() -> Self {
73 Self::new(3, 100)
74 }
75}
76
77impl Fit<Array2<Float>, ()> for GenomicKernel<Untrained> {
78 type Fitted = GenomicKernel<Trained>;
79
80 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
81 let n_samples = x.nrows();
82 if n_samples == 0 {
83 return Err(SklearsError::InvalidInput(
84 "Input array cannot be empty".to_string(),
85 ));
86 }
87
88 let vocab_size = 4usize.pow(self.k as u32);
91 let mut kmer_vocab = HashMap::new();
92
93 for i in 0..vocab_size {
95 let kmer = format!("kmer_{}", i);
96 kmer_vocab.insert(kmer, i);
97 }
98
99 let mut rng = thread_rng();
101 let normal =
102 Normal::new(0.0, 1.0 / (vocab_size as Float).sqrt()).expect("operation should succeed");
103
104 let mut projection = Array2::zeros((vocab_size, self.n_components));
105 for i in 0..vocab_size {
106 for j in 0..self.n_components {
107 projection[[i, j]] = normal.sample(&mut rng);
108 }
109 }
110
111 Ok(GenomicKernel {
112 k: self.k,
113 n_components: self.n_components,
114 normalize: self.normalize,
115 projection: Some(projection),
116 kmer_vocab: Some(kmer_vocab),
117 _state: PhantomData,
118 })
119 }
120}
121
122impl Transform<Array2<Float>, Array2<Float>> for GenomicKernel<Trained> {
123 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
124 let n_samples = x.nrows();
125 let n_features = x.ncols();
126
127 if n_samples == 0 {
128 return Err(SklearsError::InvalidInput(
129 "Input array cannot be empty".to_string(),
130 ));
131 }
132
133 let projection = self.projection.as_ref().expect("operation should succeed");
134 let vocab_size = projection.nrows();
135
136 let mut kmer_counts = Array2::zeros((n_samples, vocab_size));
138
139 for i in 0..n_samples {
140 for j in 0..n_features.min(vocab_size) {
141 kmer_counts[[i, j]] = x[[i, j % n_features]].abs();
143 }
144
145 if self.normalize {
147 let row_sum: Float = kmer_counts.row(i).sum();
148 if row_sum > 0.0 {
149 for j in 0..vocab_size {
150 kmer_counts[[i, j]] /= row_sum;
151 }
152 }
153 }
154 }
155
156 let features = kmer_counts.dot(projection);
158
159 Ok(features)
160 }
161}
162
163pub struct ProteinKernel<State = Untrained> {
175 pattern_length: usize,
177 n_components: usize,
179 use_properties: bool,
181 projection: Option<Array2<Float>>,
183 property_weights: Option<Array1<Float>>,
185 _state: PhantomData<State>,
187}
188
189impl ProteinKernel<Untrained> {
190 pub fn new(pattern_length: usize, n_components: usize) -> Self {
192 Self {
193 pattern_length,
194 n_components,
195 use_properties: true,
196 projection: None,
197 property_weights: None,
198 _state: PhantomData,
199 }
200 }
201
202 pub fn use_properties(mut self, use_properties: bool) -> Self {
204 self.use_properties = use_properties;
205 self
206 }
207}
208
209impl Default for ProteinKernel<Untrained> {
210 fn default() -> Self {
211 Self::new(3, 100)
212 }
213}
214
215impl Fit<Array2<Float>, ()> for ProteinKernel<Untrained> {
216 type Fitted = ProteinKernel<Trained>;
217
218 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
219 let n_samples = x.nrows();
220 let _n_features = x.ncols();
221
222 if n_samples == 0 {
223 return Err(SklearsError::InvalidInput(
224 "Input array cannot be empty".to_string(),
225 ));
226 }
227
228 let feature_dim = if self.use_properties { 20 + 5 } else { 20 };
230
231 let mut rng = thread_rng();
232 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
233 .expect("operation should succeed");
234
235 let mut projection = Array2::zeros((feature_dim * self.pattern_length, self.n_components));
237 for i in 0..(feature_dim * self.pattern_length) {
238 for j in 0..self.n_components {
239 projection[[i, j]] = normal.sample(&mut rng);
240 }
241 }
242
243 let property_weights = if self.use_properties {
245 Some(Array1::from_vec(vec![1.0, 0.8, 0.6, 0.7, 0.5]))
246 } else {
247 None
248 };
249
250 Ok(ProteinKernel {
251 pattern_length: self.pattern_length,
252 n_components: self.n_components,
253 use_properties: self.use_properties,
254 projection: Some(projection),
255 property_weights,
256 _state: PhantomData,
257 })
258 }
259}
260
261impl Transform<Array2<Float>, Array2<Float>> for ProteinKernel<Trained> {
262 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
263 let n_samples = x.nrows();
264 let n_features = x.ncols();
265
266 if n_samples == 0 {
267 return Err(SklearsError::InvalidInput(
268 "Input array cannot be empty".to_string(),
269 ));
270 }
271
272 let projection = self.projection.as_ref().expect("operation should succeed");
273 let feature_dim = projection.nrows();
274
275 let mut protein_features = Array2::zeros((n_samples, feature_dim));
277
278 for i in 0..n_samples {
279 for j in 0..n_features.min(feature_dim) {
280 let aa_value = x[[i, j % n_features]].abs();
282 protein_features[[i, j]] = aa_value;
283
284 if self.use_properties && j + 20 < feature_dim {
286 if let Some(weights) = &self.property_weights {
287 for (prop_idx, &weight) in weights.iter().enumerate() {
288 if j + 20 + prop_idx < feature_dim {
289 protein_features[[i, j + 20 + prop_idx]] = aa_value * weight;
290 }
291 }
292 }
293 }
294 }
295 }
296
297 let features = protein_features.dot(projection);
299
300 Ok(features)
301 }
302}
303
304pub struct PhylogeneticKernel<State = Untrained> {
316 n_components: usize,
318 tree_depth: usize,
320 use_branch_lengths: bool,
322 projection: Option<Array2<Float>>,
324 branch_weights: Option<Array1<Float>>,
326 _state: PhantomData<State>,
328}
329
330impl PhylogeneticKernel<Untrained> {
331 pub fn new(n_components: usize, tree_depth: usize) -> Self {
333 Self {
334 n_components,
335 tree_depth,
336 use_branch_lengths: true,
337 projection: None,
338 branch_weights: None,
339 _state: PhantomData,
340 }
341 }
342
343 pub fn use_branch_lengths(mut self, use_branch_lengths: bool) -> Self {
345 self.use_branch_lengths = use_branch_lengths;
346 self
347 }
348}
349
350impl Default for PhylogeneticKernel<Untrained> {
351 fn default() -> Self {
352 Self::new(100, 5)
353 }
354}
355
356impl Fit<Array2<Float>, ()> for PhylogeneticKernel<Untrained> {
357 type Fitted = PhylogeneticKernel<Trained>;
358
359 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
360 let n_samples = x.nrows();
361 if n_samples == 0 {
362 return Err(SklearsError::InvalidInput(
363 "Input array cannot be empty".to_string(),
364 ));
365 }
366
367 let feature_dim = 2usize.pow(self.tree_depth as u32);
369
370 let mut rng = thread_rng();
371 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
372 .expect("operation should succeed");
373
374 let mut projection = Array2::zeros((feature_dim, self.n_components));
376 for i in 0..feature_dim {
377 for j in 0..self.n_components {
378 projection[[i, j]] = normal.sample(&mut rng);
379 }
380 }
381
382 let branch_weights = if self.use_branch_lengths {
384 let mut weights = Array1::zeros(self.tree_depth);
385 for i in 0..self.tree_depth {
386 weights[i] = (-(i as Float) * 0.5).exp();
387 }
388 Some(weights)
389 } else {
390 None
391 };
392
393 Ok(PhylogeneticKernel {
394 n_components: self.n_components,
395 tree_depth: self.tree_depth,
396 use_branch_lengths: self.use_branch_lengths,
397 projection: Some(projection),
398 branch_weights,
399 _state: PhantomData,
400 })
401 }
402}
403
404impl Transform<Array2<Float>, Array2<Float>> for PhylogeneticKernel<Trained> {
405 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
406 let n_samples = x.nrows();
407 let n_features = x.ncols();
408
409 if n_samples == 0 {
410 return Err(SklearsError::InvalidInput(
411 "Input array cannot be empty".to_string(),
412 ));
413 }
414
415 let projection = self.projection.as_ref().expect("operation should succeed");
416 let feature_dim = projection.nrows();
417
418 let mut tree_features = Array2::zeros((n_samples, feature_dim));
420
421 for i in 0..n_samples {
422 for j in 0..n_features.min(feature_dim) {
423 let base_value = x[[i, j % n_features]].abs();
424
425 if self.use_branch_lengths {
427 if let Some(weights) = &self.branch_weights {
428 let depth_idx = j % self.tree_depth;
429 tree_features[[i, j]] = base_value * weights[depth_idx];
430 } else {
431 tree_features[[i, j]] = base_value;
432 }
433 } else {
434 tree_features[[i, j]] = base_value;
435 }
436 }
437 }
438
439 let features = tree_features.dot(projection);
441
442 Ok(features)
443 }
444}
445
446pub struct MetabolicNetworkKernel<State = Untrained> {
458 n_components: usize,
460 max_path_length: usize,
462 use_pathway_enrichment: bool,
464 projection: Option<Array2<Float>>,
466 pathway_weights: Option<Array1<Float>>,
468 _state: PhantomData<State>,
470}
471
472impl MetabolicNetworkKernel<Untrained> {
473 pub fn new(n_components: usize, max_path_length: usize) -> Self {
475 Self {
476 n_components,
477 max_path_length,
478 use_pathway_enrichment: true,
479 projection: None,
480 pathway_weights: None,
481 _state: PhantomData,
482 }
483 }
484
485 pub fn use_pathway_enrichment(mut self, use_pathway_enrichment: bool) -> Self {
487 self.use_pathway_enrichment = use_pathway_enrichment;
488 self
489 }
490}
491
492impl Default for MetabolicNetworkKernel<Untrained> {
493 fn default() -> Self {
494 Self::new(100, 4)
495 }
496}
497
498impl Fit<Array2<Float>, ()> for MetabolicNetworkKernel<Untrained> {
499 type Fitted = MetabolicNetworkKernel<Trained>;
500
501 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
502 let n_samples = x.nrows();
503 if n_samples == 0 {
504 return Err(SklearsError::InvalidInput(
505 "Input array cannot be empty".to_string(),
506 ));
507 }
508
509 let base_dim = 50; let pathway_dim = if self.use_pathway_enrichment { 20 } else { 0 };
512 let feature_dim = base_dim + pathway_dim;
513
514 let mut rng = thread_rng();
515 let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
516 .expect("operation should succeed");
517
518 let mut projection = Array2::zeros((feature_dim, self.n_components));
520 for i in 0..feature_dim {
521 for j in 0..self.n_components {
522 projection[[i, j]] = normal.sample(&mut rng);
523 }
524 }
525
526 let pathway_weights = if self.use_pathway_enrichment {
528 let mut weights = Array1::zeros(pathway_dim);
529 for i in 0..pathway_dim {
530 weights[i] = 1.0 / (1.0 + (i as Float) * 0.1);
532 }
533 Some(weights)
534 } else {
535 None
536 };
537
538 Ok(MetabolicNetworkKernel {
539 n_components: self.n_components,
540 max_path_length: self.max_path_length,
541 use_pathway_enrichment: self.use_pathway_enrichment,
542 projection: Some(projection),
543 pathway_weights,
544 _state: PhantomData,
545 })
546 }
547}
548
549impl Transform<Array2<Float>, Array2<Float>> for MetabolicNetworkKernel<Trained> {
550 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
551 let n_samples = x.nrows();
552 let n_features = x.ncols();
553
554 if n_samples == 0 {
555 return Err(SklearsError::InvalidInput(
556 "Input array cannot be empty".to_string(),
557 ));
558 }
559
560 let projection = self.projection.as_ref().expect("operation should succeed");
561 let feature_dim = projection.nrows();
562
563 let mut network_features = Array2::zeros((n_samples, feature_dim));
565
566 for i in 0..n_samples {
567 for j in 0..n_features.min(feature_dim) {
569 network_features[[i, j]] = x[[i, j % n_features]].abs();
570 }
571
572 if self.use_pathway_enrichment {
574 if let Some(weights) = &self.pathway_weights {
575 let pathway_start = 50;
576 for (pathway_idx, &weight) in weights.iter().enumerate() {
577 if pathway_start + pathway_idx < feature_dim {
578 let pathway_value = x[[i, pathway_idx % n_features]].abs() * weight;
580 network_features[[i, pathway_start + pathway_idx]] = pathway_value;
581 }
582 }
583 }
584 }
585 }
586
587 let features = network_features.dot(projection);
589
590 Ok(features)
591 }
592}
593
/// Strategy `MultiOmicsKernel` uses to combine the per-omics projections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OmicsIntegrationMethod {
    /// Unweighted mean of the per-omics projections.
    Concatenation,
    /// Mean of the per-omics projections weighted by the fitted omics weights.
    WeightedAverage,
    /// Each projection is augmented with a scaled sum of the other
    /// projections before taking the unweighted mean.
    CrossCorrelation,
    /// Weighted combination of the per-omics projections using the fitted
    /// omics weights.
    MultiViewLearning,
}
610
611pub struct MultiOmicsKernel<State = Untrained> {
619 n_components: usize,
621 n_omics_types: usize,
623 integration_method: OmicsIntegrationMethod,
625 projections: Option<Vec<Array2<Float>>>,
627 omics_weights: Option<Array1<Float>>,
629 _state: PhantomData<State>,
631}
632
633impl MultiOmicsKernel<Untrained> {
634 pub fn new(n_components: usize, n_omics_types: usize) -> Self {
636 Self {
637 n_components,
638 n_omics_types,
639 integration_method: OmicsIntegrationMethod::WeightedAverage,
640 projections: None,
641 omics_weights: None,
642 _state: PhantomData,
643 }
644 }
645
646 pub fn integration_method(mut self, method: OmicsIntegrationMethod) -> Self {
648 self.integration_method = method;
649 self
650 }
651}
652
653impl Default for MultiOmicsKernel<Untrained> {
654 fn default() -> Self {
655 Self::new(100, 3)
656 }
657}
658
659impl Fit<Array2<Float>, ()> for MultiOmicsKernel<Untrained> {
660 type Fitted = MultiOmicsKernel<Trained>;
661
662 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
663 let n_samples = x.nrows();
664 let n_features = x.ncols();
665
666 if n_samples == 0 {
667 return Err(SklearsError::InvalidInput(
668 "Input array cannot be empty".to_string(),
669 ));
670 }
671
672 let features_per_omics = n_features / self.n_omics_types;
674
675 let mut rng = thread_rng();
676
677 let mut projections = Vec::new();
679 for _ in 0..self.n_omics_types {
680 let normal = Normal::new(0.0, 1.0 / (features_per_omics as Float).sqrt())
681 .expect("operation should succeed");
682 let mut projection = Array2::zeros((features_per_omics, self.n_components));
683
684 for i in 0..features_per_omics {
685 for j in 0..self.n_components {
686 projection[[i, j]] = normal.sample(&mut rng);
687 }
688 }
689 projections.push(projection);
690 }
691
692 let mut omics_weights = Array1::zeros(self.n_omics_types);
694 for i in 0..self.n_omics_types {
695 omics_weights[i] = 1.0 / (1.0 + (i as Float) * 0.2);
697 }
698
699 Ok(MultiOmicsKernel {
700 n_components: self.n_components,
701 n_omics_types: self.n_omics_types,
702 integration_method: self.integration_method,
703 projections: Some(projections),
704 omics_weights: Some(omics_weights),
705 _state: PhantomData,
706 })
707 }
708}
709
710impl Transform<Array2<Float>, Array2<Float>> for MultiOmicsKernel<Trained> {
711 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
712 let n_samples = x.nrows();
713 let n_features = x.ncols();
714
715 if n_samples == 0 {
716 return Err(SklearsError::InvalidInput(
717 "Input array cannot be empty".to_string(),
718 ));
719 }
720
721 let projections = self.projections.as_ref().expect("operation should succeed");
722 let omics_weights = self
723 .omics_weights
724 .as_ref()
725 .expect("operation should succeed");
726 let features_per_omics = n_features / self.n_omics_types;
727
728 let mut result = Array2::zeros((n_samples, self.n_components));
729
730 match self.integration_method {
731 OmicsIntegrationMethod::Concatenation => {
732 for omics_idx in 0..self.n_omics_types {
735 let start_idx = omics_idx * features_per_omics;
736 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
737
738 if start_idx < n_features {
739 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
740 for i in 0..n_samples {
741 for j in 0..(end_idx - start_idx) {
742 omics_data[[i, j]] = x[[i, start_idx + j]];
743 }
744 }
745 let omics_features = omics_data.dot(&projections[omics_idx]);
746 result += &omics_features;
747 }
748 }
749 result /= self.n_omics_types as Float;
750 }
751 OmicsIntegrationMethod::WeightedAverage => {
752 for omics_idx in 0..self.n_omics_types {
754 let start_idx = omics_idx * features_per_omics;
755 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
756
757 if start_idx < n_features {
758 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
759 for i in 0..n_samples {
760 for j in 0..(end_idx - start_idx) {
761 omics_data[[i, j]] = x[[i, start_idx + j]];
762 }
763 }
764 let omics_features = omics_data.dot(&projections[omics_idx]);
765 let weight = omics_weights[omics_idx];
766 result += &(omics_features * weight);
767 }
768 }
769 let weight_sum: Float = omics_weights.sum();
771 result /= weight_sum;
772 }
773 OmicsIntegrationMethod::CrossCorrelation => {
774 for omics_idx in 0..self.n_omics_types {
776 let start_idx = omics_idx * features_per_omics;
777 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
778
779 if start_idx < n_features {
780 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
781 for i in 0..n_samples {
782 for j in 0..(end_idx - start_idx) {
783 omics_data[[i, j]] = x[[i, start_idx + j]];
784 }
785 }
786 let mut omics_features = omics_data.dot(&projections[omics_idx]);
787
788 for other_idx in 0..self.n_omics_types {
790 if other_idx != omics_idx {
791 let other_start = other_idx * features_per_omics;
792 let other_end =
793 ((other_idx + 1) * features_per_omics).min(n_features);
794
795 if other_start < n_features {
796 let mut other_data =
797 Array2::zeros((n_samples, other_end - other_start));
798 for i in 0..n_samples {
799 for j in 0..(other_end - other_start) {
800 other_data[[i, j]] = x[[i, other_start + j]];
801 }
802 }
803 let other_features = other_data.dot(&projections[other_idx]);
804 omics_features += &(&other_features * 0.1);
806 }
807 }
808 }
809
810 result += &omics_features;
811 }
812 }
813 result /= self.n_omics_types as Float;
814 }
815 OmicsIntegrationMethod::MultiViewLearning => {
816 let mut view_features = Vec::new();
818
819 for omics_idx in 0..self.n_omics_types {
820 let start_idx = omics_idx * features_per_omics;
821 let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
822
823 if start_idx < n_features {
824 let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
825 for i in 0..n_samples {
826 for j in 0..(end_idx - start_idx) {
827 omics_data[[i, j]] = x[[i, start_idx + j]];
828 }
829 }
830 let omics_features = omics_data.dot(&projections[omics_idx]);
831 view_features.push(omics_features);
832 }
833 }
834
835 for (idx, features) in view_features.iter().enumerate() {
837 let weight = omics_weights[idx];
838 result += &(features * weight);
839 }
840 let weight_sum: Float = omics_weights.sum();
841 result /= weight_sum;
842 }
843 }
844
845 Ok(result)
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A default-normalized genomic kernel yields one row per sample and
    /// `n_components` columns.
    #[test]
    fn test_genomic_kernel_basic() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let model = GenomicKernel::new(3, 50)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (3, 50));
    }

    /// Disabling normalization does not change the output shape.
    #[test]
    fn test_genomic_kernel_normalization() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model = GenomicKernel::new(3, 30)
            .normalize(false)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 30));
    }

    /// The protein kernel yields `n_components` columns per sample.
    #[test]
    fn test_protein_kernel_basic() {
        let data = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let model = ProteinKernel::new(2, 40)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 40));
    }

    /// Enabling properties stores the property weights on the fitted kernel.
    #[test]
    fn test_protein_kernel_properties() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let model = ProteinKernel::new(2, 30)
            .use_properties(true)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 30));
        assert!(model.property_weights.is_some());
    }

    /// The phylogenetic kernel yields `n_components` columns per sample.
    #[test]
    fn test_phylogenetic_kernel_basic() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model = PhylogeneticKernel::new(50, 4)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 50));
    }

    /// Enabling branch lengths stores the branch weights on the fitted kernel.
    #[test]
    fn test_phylogenetic_kernel_branch_lengths() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let model = PhylogeneticKernel::new(40, 3)
            .use_branch_lengths(true)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 40));
        assert!(model.branch_weights.is_some());
    }

    /// The metabolic network kernel yields `n_components` columns per sample.
    #[test]
    fn test_metabolic_network_kernel_basic() {
        let data = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let model = MetabolicNetworkKernel::new(60, 3)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 60));
    }

    /// Enabling pathway enrichment stores the pathway weights on the fitted kernel.
    #[test]
    fn test_metabolic_network_kernel_pathways() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let model = MetabolicNetworkKernel::new(50, 3)
            .use_pathway_enrichment(true)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 50));
        assert!(model.pathway_weights.is_some());
    }

    /// Six features split evenly across three omics blocks.
    #[test]
    fn test_multi_omics_kernel_basic() {
        let data = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        ];

        let model = MultiOmicsKernel::new(40, 3)
            .fit(&data, &())
            .expect("operation should succeed");
        let output = model.transform(&data).expect("operation should succeed");

        assert_eq!(output.dim(), (2, 40));
    }

    /// Every integration method produces the same output shape.
    #[test]
    fn test_multi_omics_integration_methods() {
        let data = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let all_methods = [
            OmicsIntegrationMethod::Concatenation,
            OmicsIntegrationMethod::WeightedAverage,
            OmicsIntegrationMethod::CrossCorrelation,
            OmicsIntegrationMethod::MultiViewLearning,
        ];

        for method in all_methods.iter().copied() {
            let model = MultiOmicsKernel::new(30, 2)
                .integration_method(method)
                .fit(&data, &())
                .expect("operation should succeed");
            let output = model.transform(&data).expect("operation should succeed");
            assert_eq!(output.dim(), (2, 30));
        }
    }

    /// Fitting on an input without rows is rejected.
    #[test]
    fn test_empty_input_error() {
        let no_rows: Array2<Float> = Array2::zeros((0, 3));

        let genomic_result = GenomicKernel::new(3, 50).fit(&no_rows, &());
        assert!(genomic_result.is_err());

        let protein_result = ProteinKernel::new(2, 40).fit(&no_rows, &());
        assert!(protein_result.is_err());
    }
}