sklears_kernel_approximation/
bioinformatics_kernels.rs

//! Bioinformatics Kernel Methods
//!
//! This module implements kernel methods for bioinformatics applications,
//! including genomic analysis, protein sequence analysis, phylogenetic analysis,
//! metabolic networks, and multi-omics integration.
//!
//! Note: the sequence/structure feature extraction in these kernels is currently
//! simulated from numeric input matrices (each kernel documents how its input
//! columns are interpreted) and followed by a Gaussian random projection.
//!
//! # References
//! - Leslie et al. (2004): "Mismatch string kernels for discriminative protein classification"
//! - Shawe-Taylor & Cristianini (2004): "Kernel Methods for Pattern Analysis"
//! - Vert et al. (2004): "A primer on kernel methods in computational biology"
//! - Borgwardt et al. (2005): "Protein function prediction via graph kernels"

13use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::{thread_rng, Distribution};
16use sklears_core::{
17    error::{Result, SklearsError},
18    prelude::{Fit, Transform},
19    traits::{Trained, Untrained},
20    types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25// ============================================================================
26// Genomic Kernel
27// ============================================================================
28
29/// Kernel method for genomic sequence analysis using k-mer features
30///
31/// This kernel approximates similarity between genomic sequences (DNA/RNA)
32/// using k-mer (k-length subsequence) counting and random feature projection.
33///
34/// # References
35/// - Leslie et al. (2002): "The spectrum kernel: A string kernel for SVM protein classification"
36pub struct GenomicKernel<State = Untrained> {
37    /// K-mer length (typically 3-8 for DNA)
38    k: usize,
39    /// Number of random features for kernel approximation
40    n_components: usize,
41    /// Whether to normalize k-mer counts
42    normalize: bool,
43    /// Random projection matrix (for trained state)
44    projection: Option<Array2<Float>>,
45    /// K-mer vocabulary mapping (for trained state)
46    kmer_vocab: Option<HashMap<String, usize>>,
47    /// State marker
48    _state: PhantomData<State>,
49}
50
51impl GenomicKernel<Untrained> {
52    /// Create a new genomic kernel with specified k-mer length
53    pub fn new(k: usize, n_components: usize) -> Self {
54        Self {
55            k,
56            n_components,
57            normalize: true,
58            projection: None,
59            kmer_vocab: None,
60            _state: PhantomData,
61        }
62    }
63
64    /// Set whether to normalize k-mer counts
65    pub fn normalize(mut self, normalize: bool) -> Self {
66        self.normalize = normalize;
67        self
68    }
69}
70
71impl Default for GenomicKernel<Untrained> {
72    fn default() -> Self {
73        Self::new(3, 100)
74    }
75}
76
77impl Fit<Array2<Float>, ()> for GenomicKernel<Untrained> {
78    type Fitted = GenomicKernel<Trained>;
79
80    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
81        let n_samples = x.nrows();
82        if n_samples == 0 {
83            return Err(SklearsError::InvalidInput(
84                "Input array cannot be empty".to_string(),
85            ));
86        }
87
88        // Build k-mer vocabulary from data
89        // In practice, DNA has 4 bases (A,C,G,T), so vocab size is 4^k
90        let vocab_size = 4usize.pow(self.k as u32);
91        let mut kmer_vocab = HashMap::new();
92
93        // Create synthetic k-mer vocabulary (in real use, extract from sequences)
94        for i in 0..vocab_size {
95            let kmer = format!("kmer_{}", i);
96            kmer_vocab.insert(kmer, i);
97        }
98
99        // Generate random projection matrix for dimensionality reduction
100        let mut rng = thread_rng();
101        let normal = Normal::new(0.0, 1.0 / (vocab_size as Float).sqrt()).unwrap();
102
103        let mut projection = Array2::zeros((vocab_size, self.n_components));
104        for i in 0..vocab_size {
105            for j in 0..self.n_components {
106                projection[[i, j]] = normal.sample(&mut rng);
107            }
108        }
109
110        Ok(GenomicKernel {
111            k: self.k,
112            n_components: self.n_components,
113            normalize: self.normalize,
114            projection: Some(projection),
115            kmer_vocab: Some(kmer_vocab),
116            _state: PhantomData,
117        })
118    }
119}
120
121impl Transform<Array2<Float>, Array2<Float>> for GenomicKernel<Trained> {
122    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
123        let n_samples = x.nrows();
124        let n_features = x.ncols();
125
126        if n_samples == 0 {
127            return Err(SklearsError::InvalidInput(
128                "Input array cannot be empty".to_string(),
129            ));
130        }
131
132        let projection = self.projection.as_ref().unwrap();
133        let vocab_size = projection.nrows();
134
135        // Extract k-mer features from input (simulated)
136        let mut kmer_counts = Array2::zeros((n_samples, vocab_size));
137
138        for i in 0..n_samples {
139            for j in 0..n_features.min(vocab_size) {
140                // Simulated k-mer counting (in real use, count k-mers from sequences)
141                kmer_counts[[i, j]] = x[[i, j % n_features]].abs();
142            }
143
144            // Normalize if requested
145            if self.normalize {
146                let row_sum: Float = kmer_counts.row(i).sum();
147                if row_sum > 0.0 {
148                    for j in 0..vocab_size {
149                        kmer_counts[[i, j]] /= row_sum;
150                    }
151                }
152            }
153        }
154
155        // Apply random projection
156        let features = kmer_counts.dot(projection);
157
158        Ok(features)
159    }
160}
161
162// ============================================================================
163// Protein Kernel
164// ============================================================================
165
166/// Kernel method for protein sequence analysis with physicochemical properties
167///
168/// This kernel incorporates amino acid substitution matrices and physicochemical
169/// properties for protein sequence comparison.
170///
171/// # References
172/// - Henikoff & Henikoff (1992): "Amino acid substitution matrices from protein blocks"
173pub struct ProteinKernel<State = Untrained> {
174    /// Length of amino acid patterns to extract
175    pattern_length: usize,
176    /// Number of random features
177    n_components: usize,
178    /// Whether to use physicochemical properties
179    use_properties: bool,
180    /// Random projection matrix
181    projection: Option<Array2<Float>>,
182    /// Amino acid property weights
183    property_weights: Option<Array1<Float>>,
184    /// State marker
185    _state: PhantomData<State>,
186}
187
188impl ProteinKernel<Untrained> {
189    /// Create a new protein kernel
190    pub fn new(pattern_length: usize, n_components: usize) -> Self {
191        Self {
192            pattern_length,
193            n_components,
194            use_properties: true,
195            projection: None,
196            property_weights: None,
197            _state: PhantomData,
198        }
199    }
200
201    /// Set whether to use physicochemical properties
202    pub fn use_properties(mut self, use_properties: bool) -> Self {
203        self.use_properties = use_properties;
204        self
205    }
206}
207
208impl Default for ProteinKernel<Untrained> {
209    fn default() -> Self {
210        Self::new(3, 100)
211    }
212}
213
214impl Fit<Array2<Float>, ()> for ProteinKernel<Untrained> {
215    type Fitted = ProteinKernel<Trained>;
216
217    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
218        let n_samples = x.nrows();
219        let _n_features = x.ncols();
220
221        if n_samples == 0 {
222            return Err(SklearsError::InvalidInput(
223                "Input array cannot be empty".to_string(),
224            ));
225        }
226
227        // 20 amino acids + physicochemical properties
228        let feature_dim = if self.use_properties { 20 + 5 } else { 20 };
229
230        let mut rng = thread_rng();
231        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
232
233        // Generate random projection
234        let mut projection = Array2::zeros((feature_dim * self.pattern_length, self.n_components));
235        for i in 0..(feature_dim * self.pattern_length) {
236            for j in 0..self.n_components {
237                projection[[i, j]] = normal.sample(&mut rng);
238            }
239        }
240
241        // Initialize physicochemical property weights (hydrophobicity, charge, size, polarity, aromaticity)
242        let property_weights = if self.use_properties {
243            Some(Array1::from_vec(vec![1.0, 0.8, 0.6, 0.7, 0.5]))
244        } else {
245            None
246        };
247
248        Ok(ProteinKernel {
249            pattern_length: self.pattern_length,
250            n_components: self.n_components,
251            use_properties: self.use_properties,
252            projection: Some(projection),
253            property_weights,
254            _state: PhantomData,
255        })
256    }
257}
258
259impl Transform<Array2<Float>, Array2<Float>> for ProteinKernel<Trained> {
260    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
261        let n_samples = x.nrows();
262        let n_features = x.ncols();
263
264        if n_samples == 0 {
265            return Err(SklearsError::InvalidInput(
266                "Input array cannot be empty".to_string(),
267            ));
268        }
269
270        let projection = self.projection.as_ref().unwrap();
271        let feature_dim = projection.nrows();
272
273        // Extract protein features
274        let mut protein_features = Array2::zeros((n_samples, feature_dim));
275
276        for i in 0..n_samples {
277            for j in 0..n_features.min(feature_dim) {
278                // Simulated amino acid encoding with physicochemical properties
279                let aa_value = x[[i, j % n_features]].abs();
280                protein_features[[i, j]] = aa_value;
281
282                // Add physicochemical property contributions if enabled
283                if self.use_properties && j + 20 < feature_dim {
284                    if let Some(weights) = &self.property_weights {
285                        for (prop_idx, &weight) in weights.iter().enumerate() {
286                            if j + 20 + prop_idx < feature_dim {
287                                protein_features[[i, j + 20 + prop_idx]] = aa_value * weight;
288                            }
289                        }
290                    }
291                }
292            }
293        }
294
295        // Apply random projection
296        let features = protein_features.dot(projection);
297
298        Ok(features)
299    }
300}
301
302// ============================================================================
303// Phylogenetic Kernel
304// ============================================================================
305
306/// Kernel method for phylogenetic analysis using evolutionary distances
307///
308/// This kernel computes features based on phylogenetic tree structure and
309/// evolutionary distances between species.
310///
311/// # References
312/// - Vert (2002): "A tree kernel to analyze phylogenetic profiles"
313pub struct PhylogeneticKernel<State = Untrained> {
314    /// Number of random features
315    n_components: usize,
316    /// Tree depth to consider
317    tree_depth: usize,
318    /// Whether to weight by branch length
319    use_branch_lengths: bool,
320    /// Random projection matrix
321    projection: Option<Array2<Float>>,
322    /// Branch length weights
323    branch_weights: Option<Array1<Float>>,
324    /// State marker
325    _state: PhantomData<State>,
326}
327
328impl PhylogeneticKernel<Untrained> {
329    /// Create a new phylogenetic kernel
330    pub fn new(n_components: usize, tree_depth: usize) -> Self {
331        Self {
332            n_components,
333            tree_depth,
334            use_branch_lengths: true,
335            projection: None,
336            branch_weights: None,
337            _state: PhantomData,
338        }
339    }
340
341    /// Set whether to use branch lengths for weighting
342    pub fn use_branch_lengths(mut self, use_branch_lengths: bool) -> Self {
343        self.use_branch_lengths = use_branch_lengths;
344        self
345    }
346}
347
348impl Default for PhylogeneticKernel<Untrained> {
349    fn default() -> Self {
350        Self::new(100, 5)
351    }
352}
353
354impl Fit<Array2<Float>, ()> for PhylogeneticKernel<Untrained> {
355    type Fitted = PhylogeneticKernel<Trained>;
356
357    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
358        let n_samples = x.nrows();
359        if n_samples == 0 {
360            return Err(SklearsError::InvalidInput(
361                "Input array cannot be empty".to_string(),
362            ));
363        }
364
365        // Feature dimension based on tree structure
366        let feature_dim = 2usize.pow(self.tree_depth as u32);
367
368        let mut rng = thread_rng();
369        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
370
371        // Generate random projection
372        let mut projection = Array2::zeros((feature_dim, self.n_components));
373        for i in 0..feature_dim {
374            for j in 0..self.n_components {
375                projection[[i, j]] = normal.sample(&mut rng);
376            }
377        }
378
379        // Initialize branch length weights (exponentially decaying with depth)
380        let branch_weights = if self.use_branch_lengths {
381            let mut weights = Array1::zeros(self.tree_depth);
382            for i in 0..self.tree_depth {
383                weights[i] = (-(i as Float) * 0.5).exp();
384            }
385            Some(weights)
386        } else {
387            None
388        };
389
390        Ok(PhylogeneticKernel {
391            n_components: self.n_components,
392            tree_depth: self.tree_depth,
393            use_branch_lengths: self.use_branch_lengths,
394            projection: Some(projection),
395            branch_weights,
396            _state: PhantomData,
397        })
398    }
399}
400
401impl Transform<Array2<Float>, Array2<Float>> for PhylogeneticKernel<Trained> {
402    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
403        let n_samples = x.nrows();
404        let n_features = x.ncols();
405
406        if n_samples == 0 {
407            return Err(SklearsError::InvalidInput(
408                "Input array cannot be empty".to_string(),
409            ));
410        }
411
412        let projection = self.projection.as_ref().unwrap();
413        let feature_dim = projection.nrows();
414
415        // Extract phylogenetic features
416        let mut tree_features = Array2::zeros((n_samples, feature_dim));
417
418        for i in 0..n_samples {
419            for j in 0..n_features.min(feature_dim) {
420                let base_value = x[[i, j % n_features]].abs();
421
422                // Apply branch length weighting if enabled
423                if self.use_branch_lengths {
424                    if let Some(weights) = &self.branch_weights {
425                        let depth_idx = j % self.tree_depth;
426                        tree_features[[i, j]] = base_value * weights[depth_idx];
427                    } else {
428                        tree_features[[i, j]] = base_value;
429                    }
430                } else {
431                    tree_features[[i, j]] = base_value;
432                }
433            }
434        }
435
436        // Apply random projection
437        let features = tree_features.dot(projection);
438
439        Ok(features)
440    }
441}
442
443// ============================================================================
444// Metabolic Network Kernel
445// ============================================================================
446
447/// Kernel method for metabolic network and pathway analysis
448///
449/// This kernel analyzes metabolic networks using graph-based features,
450/// pathway similarities, and network topology.
451///
452/// # References
453/// - Borgwardt et al. (2005): "Shortest-path kernels on graphs"
454pub struct MetabolicNetworkKernel<State = Untrained> {
455    /// Number of random features
456    n_components: usize,
457    /// Maximum path length to consider
458    max_path_length: usize,
459    /// Whether to include pathway enrichment
460    use_pathway_enrichment: bool,
461    /// Random projection matrix
462    projection: Option<Array2<Float>>,
463    /// Pathway weights
464    pathway_weights: Option<Array1<Float>>,
465    /// State marker
466    _state: PhantomData<State>,
467}
468
469impl MetabolicNetworkKernel<Untrained> {
470    /// Create a new metabolic network kernel
471    pub fn new(n_components: usize, max_path_length: usize) -> Self {
472        Self {
473            n_components,
474            max_path_length,
475            use_pathway_enrichment: true,
476            projection: None,
477            pathway_weights: None,
478            _state: PhantomData,
479        }
480    }
481
482    /// Set whether to use pathway enrichment features
483    pub fn use_pathway_enrichment(mut self, use_pathway_enrichment: bool) -> Self {
484        self.use_pathway_enrichment = use_pathway_enrichment;
485        self
486    }
487}
488
489impl Default for MetabolicNetworkKernel<Untrained> {
490    fn default() -> Self {
491        Self::new(100, 4)
492    }
493}
494
495impl Fit<Array2<Float>, ()> for MetabolicNetworkKernel<Untrained> {
496    type Fitted = MetabolicNetworkKernel<Trained>;
497
498    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
499        let n_samples = x.nrows();
500        if n_samples == 0 {
501            return Err(SklearsError::InvalidInput(
502                "Input array cannot be empty".to_string(),
503            ));
504        }
505
506        // Feature dimension based on network structure
507        let base_dim = 50; // Network topology features
508        let pathway_dim = if self.use_pathway_enrichment { 20 } else { 0 };
509        let feature_dim = base_dim + pathway_dim;
510
511        let mut rng = thread_rng();
512        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt()).unwrap();
513
514        // Generate random projection
515        let mut projection = Array2::zeros((feature_dim, self.n_components));
516        for i in 0..feature_dim {
517            for j in 0..self.n_components {
518                projection[[i, j]] = normal.sample(&mut rng);
519            }
520        }
521
522        // Initialize pathway enrichment weights
523        let pathway_weights = if self.use_pathway_enrichment {
524            let mut weights = Array1::zeros(pathway_dim);
525            for i in 0..pathway_dim {
526                // Different pathways have different importance
527                weights[i] = 1.0 / (1.0 + (i as Float) * 0.1);
528            }
529            Some(weights)
530        } else {
531            None
532        };
533
534        Ok(MetabolicNetworkKernel {
535            n_components: self.n_components,
536            max_path_length: self.max_path_length,
537            use_pathway_enrichment: self.use_pathway_enrichment,
538            projection: Some(projection),
539            pathway_weights,
540            _state: PhantomData,
541        })
542    }
543}
544
545impl Transform<Array2<Float>, Array2<Float>> for MetabolicNetworkKernel<Trained> {
546    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
547        let n_samples = x.nrows();
548        let n_features = x.ncols();
549
550        if n_samples == 0 {
551            return Err(SklearsError::InvalidInput(
552                "Input array cannot be empty".to_string(),
553            ));
554        }
555
556        let projection = self.projection.as_ref().unwrap();
557        let feature_dim = projection.nrows();
558
559        // Extract network features
560        let mut network_features = Array2::zeros((n_samples, feature_dim));
561
562        for i in 0..n_samples {
563            // Network topology features
564            for j in 0..n_features.min(feature_dim) {
565                network_features[[i, j]] = x[[i, j % n_features]].abs();
566            }
567
568            // Add pathway enrichment features if enabled
569            if self.use_pathway_enrichment {
570                if let Some(weights) = &self.pathway_weights {
571                    let pathway_start = 50;
572                    for (pathway_idx, &weight) in weights.iter().enumerate() {
573                        if pathway_start + pathway_idx < feature_dim {
574                            // Pathway enrichment score weighted by importance
575                            let pathway_value = x[[i, pathway_idx % n_features]].abs() * weight;
576                            network_features[[i, pathway_start + pathway_idx]] = pathway_value;
577                        }
578                    }
579                }
580            }
581        }
582
583        // Apply random projection
584        let features = network_features.dot(projection);
585
586        Ok(features)
587    }
588}
589
590// ============================================================================
591// Multi-Omics Kernel
592// ============================================================================
593
594/// Multi-omics integration method
595#[derive(Debug, Clone, Copy)]
596pub enum OmicsIntegrationMethod {
597    /// Simple concatenation of omics features
598    Concatenation,
599    /// Weighted average based on omics type importance
600    WeightedAverage,
601    /// Cross-omics correlation features
602    CrossCorrelation,
603    /// Multi-view kernel learning
604    MultiViewLearning,
605}
606
607/// Kernel method for multi-omics data integration
608///
609/// This kernel integrates multiple omics data types (genomics, transcriptomics,
610/// proteomics, metabolomics) using various integration strategies.
611///
612/// # References
613/// - Nguyen et al. (2017): "A novel approach for data integration and disease subtyping"
614pub struct MultiOmicsKernel<State = Untrained> {
615    /// Number of random features
616    n_components: usize,
617    /// Number of different omics types
618    n_omics_types: usize,
619    /// Integration method
620    integration_method: OmicsIntegrationMethod,
621    /// Random projection matrices (one per omics type)
622    projections: Option<Vec<Array2<Float>>>,
623    /// Omics type weights
624    omics_weights: Option<Array1<Float>>,
625    /// State marker
626    _state: PhantomData<State>,
627}
628
629impl MultiOmicsKernel<Untrained> {
630    /// Create a new multi-omics kernel
631    pub fn new(n_components: usize, n_omics_types: usize) -> Self {
632        Self {
633            n_components,
634            n_omics_types,
635            integration_method: OmicsIntegrationMethod::WeightedAverage,
636            projections: None,
637            omics_weights: None,
638            _state: PhantomData,
639        }
640    }
641
642    /// Set the integration method
643    pub fn integration_method(mut self, method: OmicsIntegrationMethod) -> Self {
644        self.integration_method = method;
645        self
646    }
647}
648
649impl Default for MultiOmicsKernel<Untrained> {
650    fn default() -> Self {
651        Self::new(100, 3)
652    }
653}
654
655impl Fit<Array2<Float>, ()> for MultiOmicsKernel<Untrained> {
656    type Fitted = MultiOmicsKernel<Trained>;
657
658    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
659        let n_samples = x.nrows();
660        let n_features = x.ncols();
661
662        if n_samples == 0 {
663            return Err(SklearsError::InvalidInput(
664                "Input array cannot be empty".to_string(),
665            ));
666        }
667
668        // Assume features are divided equally among omics types
669        let features_per_omics = n_features / self.n_omics_types;
670
671        let mut rng = thread_rng();
672
673        // Generate separate projection for each omics type
674        let mut projections = Vec::new();
675        for _ in 0..self.n_omics_types {
676            let normal = Normal::new(0.0, 1.0 / (features_per_omics as Float).sqrt()).unwrap();
677            let mut projection = Array2::zeros((features_per_omics, self.n_components));
678
679            for i in 0..features_per_omics {
680                for j in 0..self.n_components {
681                    projection[[i, j]] = normal.sample(&mut rng);
682                }
683            }
684            projections.push(projection);
685        }
686
687        // Initialize omics type weights (different weights for genomics, transcriptomics, proteomics, etc.)
688        let mut omics_weights = Array1::zeros(self.n_omics_types);
689        for i in 0..self.n_omics_types {
690            // Decreasing importance: genomics > transcriptomics > proteomics > metabolomics
691            omics_weights[i] = 1.0 / (1.0 + (i as Float) * 0.2);
692        }
693
694        Ok(MultiOmicsKernel {
695            n_components: self.n_components,
696            n_omics_types: self.n_omics_types,
697            integration_method: self.integration_method,
698            projections: Some(projections),
699            omics_weights: Some(omics_weights),
700            _state: PhantomData,
701        })
702    }
703}
704
705impl Transform<Array2<Float>, Array2<Float>> for MultiOmicsKernel<Trained> {
706    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
707        let n_samples = x.nrows();
708        let n_features = x.ncols();
709
710        if n_samples == 0 {
711            return Err(SklearsError::InvalidInput(
712                "Input array cannot be empty".to_string(),
713            ));
714        }
715
716        let projections = self.projections.as_ref().unwrap();
717        let omics_weights = self.omics_weights.as_ref().unwrap();
718        let features_per_omics = n_features / self.n_omics_types;
719
720        let mut result = Array2::zeros((n_samples, self.n_components));
721
722        match self.integration_method {
723            OmicsIntegrationMethod::Concatenation => {
724                // Project each omics type separately and concatenate
725                // For simplicity, we average instead of concatenating to maintain dimension
726                for omics_idx in 0..self.n_omics_types {
727                    let start_idx = omics_idx * features_per_omics;
728                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
729
730                    if start_idx < n_features {
731                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
732                        for i in 0..n_samples {
733                            for j in 0..(end_idx - start_idx) {
734                                omics_data[[i, j]] = x[[i, start_idx + j]];
735                            }
736                        }
737                        let omics_features = omics_data.dot(&projections[omics_idx]);
738                        result += &omics_features;
739                    }
740                }
741                result /= self.n_omics_types as Float;
742            }
743            OmicsIntegrationMethod::WeightedAverage => {
744                // Weighted combination of omics-specific features
745                for omics_idx in 0..self.n_omics_types {
746                    let start_idx = omics_idx * features_per_omics;
747                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
748
749                    if start_idx < n_features {
750                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
751                        for i in 0..n_samples {
752                            for j in 0..(end_idx - start_idx) {
753                                omics_data[[i, j]] = x[[i, start_idx + j]];
754                            }
755                        }
756                        let omics_features = omics_data.dot(&projections[omics_idx]);
757                        let weight = omics_weights[omics_idx];
758                        result += &(omics_features * weight);
759                    }
760                }
761                // Normalize by sum of weights
762                let weight_sum: Float = omics_weights.sum();
763                result /= weight_sum;
764            }
765            OmicsIntegrationMethod::CrossCorrelation => {
766                // Include cross-correlation between omics types
767                for omics_idx in 0..self.n_omics_types {
768                    let start_idx = omics_idx * features_per_omics;
769                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
770
771                    if start_idx < n_features {
772                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
773                        for i in 0..n_samples {
774                            for j in 0..(end_idx - start_idx) {
775                                omics_data[[i, j]] = x[[i, start_idx + j]];
776                            }
777                        }
778                        let mut omics_features = omics_data.dot(&projections[omics_idx]);
779
780                        // Add cross-correlation with other omics types
781                        for other_idx in 0..self.n_omics_types {
782                            if other_idx != omics_idx {
783                                let other_start = other_idx * features_per_omics;
784                                let other_end =
785                                    ((other_idx + 1) * features_per_omics).min(n_features);
786
787                                if other_start < n_features {
788                                    let mut other_data =
789                                        Array2::zeros((n_samples, other_end - other_start));
790                                    for i in 0..n_samples {
791                                        for j in 0..(other_end - other_start) {
792                                            other_data[[i, j]] = x[[i, other_start + j]];
793                                        }
794                                    }
795                                    let other_features = other_data.dot(&projections[other_idx]);
796                                    // Element-wise multiplication for cross-correlation
797                                    omics_features += &(&other_features * 0.1);
798                                }
799                            }
800                        }
801
802                        result += &omics_features;
803                    }
804                }
805                result /= self.n_omics_types as Float;
806            }
807            OmicsIntegrationMethod::MultiViewLearning => {
808                // Multi-view learning with view-specific and shared features
809                let mut view_features = Vec::new();
810
811                for omics_idx in 0..self.n_omics_types {
812                    let start_idx = omics_idx * features_per_omics;
813                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
814
815                    if start_idx < n_features {
816                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
817                        for i in 0..n_samples {
818                            for j in 0..(end_idx - start_idx) {
819                                omics_data[[i, j]] = x[[i, start_idx + j]];
820                            }
821                        }
822                        let omics_features = omics_data.dot(&projections[omics_idx]);
823                        view_features.push(omics_features);
824                    }
825                }
826
827                // Combine views using weighted average
828                for (idx, features) in view_features.iter().enumerate() {
829                    let weight = omics_weights[idx];
830                    result += &(features * weight);
831                }
832                let weight_sum: Float = omics_weights.sum();
833                result /= weight_sum;
834            }
835        }
836
837        Ok(result)
838    }
839}
840
841// ============================================================================
842// Tests
843// ============================================================================
844
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // GenomicKernel: fitting and transforming a 3x3 matrix yields one row per
    // sample and `n_components` (50) columns.
    #[test]
    fn test_genomic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let kernel = GenomicKernel::new(3, 50);
        let fitted = kernel.fit(&x, &()).expect("GenomicKernel fit failed");
        let features = fitted
            .transform(&x)
            .expect("GenomicKernel transform failed");

        assert_eq!(features.shape(), &[3, 50]);
    }

    // GenomicKernel with normalization disabled still produces the requested shape.
    #[test]
    fn test_genomic_kernel_normalization() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = GenomicKernel::new(3, 30).normalize(false);
        let fitted = kernel
            .fit(&x, &())
            .expect("GenomicKernel (normalize = false) fit failed");
        let features = fitted
            .transform(&x)
            .expect("GenomicKernel (normalize = false) transform failed");

        assert_eq!(features.shape(), &[2, 30]);
    }

    // ProteinKernel: output is (n_samples, 40) for `ProteinKernel::new(2, 40)`.
    #[test]
    fn test_protein_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = ProteinKernel::new(2, 40);
        let fitted = kernel.fit(&x, &()).expect("ProteinKernel fit failed");
        let features = fitted
            .transform(&x)
            .expect("ProteinKernel transform failed");

        assert_eq!(features.shape(), &[2, 40]);
    }

    // Enabling `use_properties` must populate `property_weights` on the fitted model.
    #[test]
    fn test_protein_kernel_properties() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = ProteinKernel::new(2, 30).use_properties(true);
        let fitted = kernel
            .fit(&x, &())
            .expect("ProteinKernel (use_properties) fit failed");
        let features = fitted
            .transform(&x)
            .expect("ProteinKernel (use_properties) transform failed");

        assert_eq!(features.shape(), &[2, 30]);
        assert!(fitted.property_weights.is_some());
    }

    // PhylogeneticKernel: output is (n_samples, 50) for `PhylogeneticKernel::new(50, 4)`.
    #[test]
    fn test_phylogenetic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = PhylogeneticKernel::new(50, 4);
        let fitted = kernel.fit(&x, &()).expect("PhylogeneticKernel fit failed");
        let features = fitted
            .transform(&x)
            .expect("PhylogeneticKernel transform failed");

        assert_eq!(features.shape(), &[2, 50]);
    }

    // Enabling `use_branch_lengths` must populate `branch_weights` on the fitted model.
    #[test]
    fn test_phylogenetic_kernel_branch_lengths() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = PhylogeneticKernel::new(40, 3).use_branch_lengths(true);
        let fitted = kernel
            .fit(&x, &())
            .expect("PhylogeneticKernel (use_branch_lengths) fit failed");
        let features = fitted
            .transform(&x)
            .expect("PhylogeneticKernel (use_branch_lengths) transform failed");

        assert_eq!(features.shape(), &[2, 40]);
        assert!(fitted.branch_weights.is_some());
    }

    // MetabolicNetworkKernel: output is (n_samples, 60) for `MetabolicNetworkKernel::new(60, 3)`.
    #[test]
    fn test_metabolic_network_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = MetabolicNetworkKernel::new(60, 3);
        let fitted = kernel
            .fit(&x, &())
            .expect("MetabolicNetworkKernel fit failed");
        let features = fitted
            .transform(&x)
            .expect("MetabolicNetworkKernel transform failed");

        assert_eq!(features.shape(), &[2, 60]);
    }

    // Enabling `use_pathway_enrichment` must populate `pathway_weights` on the fitted model.
    #[test]
    fn test_metabolic_network_kernel_pathways() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = MetabolicNetworkKernel::new(50, 3).use_pathway_enrichment(true);
        let fitted = kernel
            .fit(&x, &())
            .expect("MetabolicNetworkKernel (use_pathway_enrichment) fit failed");
        let features = fitted
            .transform(&x)
            .expect("MetabolicNetworkKernel (use_pathway_enrichment) transform failed");

        assert_eq!(features.shape(), &[2, 50]);
        assert!(fitted.pathway_weights.is_some());
    }

    // MultiOmicsKernel with the default integration method: output is (n_samples, 40).
    #[test]
    fn test_multi_omics_kernel_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        ];

        let kernel = MultiOmicsKernel::new(40, 3);
        let fitted = kernel.fit(&x, &()).expect("MultiOmicsKernel fit failed");
        let features = fitted
            .transform(&x)
            .expect("MultiOmicsKernel transform failed");

        assert_eq!(features.shape(), &[2, 40]);
    }

    // Every integration strategy must produce the same (n_samples, n_components) shape.
    #[test]
    fn test_multi_omics_integration_methods() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        // A fixed-size array iterated by value: no heap allocation is needed just to loop.
        let methods = [
            OmicsIntegrationMethod::Concatenation,
            OmicsIntegrationMethod::WeightedAverage,
            OmicsIntegrationMethod::CrossCorrelation,
            OmicsIntegrationMethod::MultiViewLearning,
        ];

        for (idx, method) in methods.into_iter().enumerate() {
            let kernel = MultiOmicsKernel::new(30, 2).integration_method(method);
            let fitted = kernel
                .fit(&x, &())
                .unwrap_or_else(|e| panic!("fit failed for integration method #{idx}: {e}"));
            let features = fitted
                .transform(&x)
                .unwrap_or_else(|e| panic!("transform failed for integration method #{idx}: {e}"));
            assert_eq!(
                features.shape(),
                &[2, 30],
                "unexpected output shape for integration method #{idx}"
            );
        }
    }

    // Fitting on a 0-row matrix must be rejected with an error rather than succeeding.
    #[test]
    fn test_empty_input_error() {
        let x_empty: Array2<Float> = Array2::zeros((0, 3));

        let kernel = GenomicKernel::new(3, 50);
        assert!(kernel.fit(&x_empty, &()).is_err());

        let kernel2 = ProteinKernel::new(2, 40);
        assert!(kernel2.fit(&x_empty, &()).is_err());
    }
}