Skip to main content

sklears_kernel_approximation/
bioinformatics_kernels.rs

1//! Bioinformatics Kernel Methods
2//!
3//! This module implements kernel methods for bioinformatics applications,
4//! including genomic analysis, protein structure, phylogenetic analysis,
5//! metabolic networks, and multi-omics integration.
6//!
7//! # References
8//! - Leslie et al. (2004): "Mismatch string kernels for discriminative protein classification"
9//! - Shawe-Taylor & Cristianini (2004): "Kernel Methods for Pattern Analysis"
10//! - Vert et al. (2004): "A primer on kernel methods in computational biology"
11//! - Borgwardt et al. (2005): "Protein function prediction via graph kernels"
12
13use scirs2_core::ndarray::{Array1, Array2};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::{thread_rng, Distribution};
16use sklears_core::{
17    error::{Result, SklearsError},
18    prelude::{Fit, Transform},
19    traits::{Trained, Untrained},
20    types::Float,
21};
22use std::collections::HashMap;
23use std::marker::PhantomData;
24
25// ============================================================================
26// Genomic Kernel
27// ============================================================================
28
29/// Kernel method for genomic sequence analysis using k-mer features
30///
31/// This kernel approximates similarity between genomic sequences (DNA/RNA)
32/// using k-mer (k-length subsequence) counting and random feature projection.
33///
34/// # References
35/// - Leslie et al. (2002): "The spectrum kernel: A string kernel for SVM protein classification"
36pub struct GenomicKernel<State = Untrained> {
37    /// K-mer length (typically 3-8 for DNA)
38    k: usize,
39    /// Number of random features for kernel approximation
40    n_components: usize,
41    /// Whether to normalize k-mer counts
42    normalize: bool,
43    /// Random projection matrix (for trained state)
44    projection: Option<Array2<Float>>,
45    /// K-mer vocabulary mapping (for trained state)
46    kmer_vocab: Option<HashMap<String, usize>>,
47    /// State marker
48    _state: PhantomData<State>,
49}
50
51impl GenomicKernel<Untrained> {
52    /// Create a new genomic kernel with specified k-mer length
53    pub fn new(k: usize, n_components: usize) -> Self {
54        Self {
55            k,
56            n_components,
57            normalize: true,
58            projection: None,
59            kmer_vocab: None,
60            _state: PhantomData,
61        }
62    }
63
64    /// Set whether to normalize k-mer counts
65    pub fn normalize(mut self, normalize: bool) -> Self {
66        self.normalize = normalize;
67        self
68    }
69}
70
71impl Default for GenomicKernel<Untrained> {
72    fn default() -> Self {
73        Self::new(3, 100)
74    }
75}
76
77impl Fit<Array2<Float>, ()> for GenomicKernel<Untrained> {
78    type Fitted = GenomicKernel<Trained>;
79
80    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
81        let n_samples = x.nrows();
82        if n_samples == 0 {
83            return Err(SklearsError::InvalidInput(
84                "Input array cannot be empty".to_string(),
85            ));
86        }
87
88        // Build k-mer vocabulary from data
89        // In practice, DNA has 4 bases (A,C,G,T), so vocab size is 4^k
90        let vocab_size = 4usize.pow(self.k as u32);
91        let mut kmer_vocab = HashMap::new();
92
93        // Create synthetic k-mer vocabulary (in real use, extract from sequences)
94        for i in 0..vocab_size {
95            let kmer = format!("kmer_{}", i);
96            kmer_vocab.insert(kmer, i);
97        }
98
99        // Generate random projection matrix for dimensionality reduction
100        let mut rng = thread_rng();
101        let normal =
102            Normal::new(0.0, 1.0 / (vocab_size as Float).sqrt()).expect("operation should succeed");
103
104        let mut projection = Array2::zeros((vocab_size, self.n_components));
105        for i in 0..vocab_size {
106            for j in 0..self.n_components {
107                projection[[i, j]] = normal.sample(&mut rng);
108            }
109        }
110
111        Ok(GenomicKernel {
112            k: self.k,
113            n_components: self.n_components,
114            normalize: self.normalize,
115            projection: Some(projection),
116            kmer_vocab: Some(kmer_vocab),
117            _state: PhantomData,
118        })
119    }
120}
121
122impl Transform<Array2<Float>, Array2<Float>> for GenomicKernel<Trained> {
123    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
124        let n_samples = x.nrows();
125        let n_features = x.ncols();
126
127        if n_samples == 0 {
128            return Err(SklearsError::InvalidInput(
129                "Input array cannot be empty".to_string(),
130            ));
131        }
132
133        let projection = self.projection.as_ref().expect("operation should succeed");
134        let vocab_size = projection.nrows();
135
136        // Extract k-mer features from input (simulated)
137        let mut kmer_counts = Array2::zeros((n_samples, vocab_size));
138
139        for i in 0..n_samples {
140            for j in 0..n_features.min(vocab_size) {
141                // Simulated k-mer counting (in real use, count k-mers from sequences)
142                kmer_counts[[i, j]] = x[[i, j % n_features]].abs();
143            }
144
145            // Normalize if requested
146            if self.normalize {
147                let row_sum: Float = kmer_counts.row(i).sum();
148                if row_sum > 0.0 {
149                    for j in 0..vocab_size {
150                        kmer_counts[[i, j]] /= row_sum;
151                    }
152                }
153            }
154        }
155
156        // Apply random projection
157        let features = kmer_counts.dot(projection);
158
159        Ok(features)
160    }
161}
162
163// ============================================================================
164// Protein Kernel
165// ============================================================================
166
167/// Kernel method for protein sequence analysis with physicochemical properties
168///
169/// This kernel incorporates amino acid substitution matrices and physicochemical
170/// properties for protein sequence comparison.
171///
172/// # References
173/// - Henikoff & Henikoff (1992): "Amino acid substitution matrices from protein blocks"
174pub struct ProteinKernel<State = Untrained> {
175    /// Length of amino acid patterns to extract
176    pattern_length: usize,
177    /// Number of random features
178    n_components: usize,
179    /// Whether to use physicochemical properties
180    use_properties: bool,
181    /// Random projection matrix
182    projection: Option<Array2<Float>>,
183    /// Amino acid property weights
184    property_weights: Option<Array1<Float>>,
185    /// State marker
186    _state: PhantomData<State>,
187}
188
189impl ProteinKernel<Untrained> {
190    /// Create a new protein kernel
191    pub fn new(pattern_length: usize, n_components: usize) -> Self {
192        Self {
193            pattern_length,
194            n_components,
195            use_properties: true,
196            projection: None,
197            property_weights: None,
198            _state: PhantomData,
199        }
200    }
201
202    /// Set whether to use physicochemical properties
203    pub fn use_properties(mut self, use_properties: bool) -> Self {
204        self.use_properties = use_properties;
205        self
206    }
207}
208
209impl Default for ProteinKernel<Untrained> {
210    fn default() -> Self {
211        Self::new(3, 100)
212    }
213}
214
215impl Fit<Array2<Float>, ()> for ProteinKernel<Untrained> {
216    type Fitted = ProteinKernel<Trained>;
217
218    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
219        let n_samples = x.nrows();
220        let _n_features = x.ncols();
221
222        if n_samples == 0 {
223            return Err(SklearsError::InvalidInput(
224                "Input array cannot be empty".to_string(),
225            ));
226        }
227
228        // 20 amino acids + physicochemical properties
229        let feature_dim = if self.use_properties { 20 + 5 } else { 20 };
230
231        let mut rng = thread_rng();
232        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
233            .expect("operation should succeed");
234
235        // Generate random projection
236        let mut projection = Array2::zeros((feature_dim * self.pattern_length, self.n_components));
237        for i in 0..(feature_dim * self.pattern_length) {
238            for j in 0..self.n_components {
239                projection[[i, j]] = normal.sample(&mut rng);
240            }
241        }
242
243        // Initialize physicochemical property weights (hydrophobicity, charge, size, polarity, aromaticity)
244        let property_weights = if self.use_properties {
245            Some(Array1::from_vec(vec![1.0, 0.8, 0.6, 0.7, 0.5]))
246        } else {
247            None
248        };
249
250        Ok(ProteinKernel {
251            pattern_length: self.pattern_length,
252            n_components: self.n_components,
253            use_properties: self.use_properties,
254            projection: Some(projection),
255            property_weights,
256            _state: PhantomData,
257        })
258    }
259}
260
261impl Transform<Array2<Float>, Array2<Float>> for ProteinKernel<Trained> {
262    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
263        let n_samples = x.nrows();
264        let n_features = x.ncols();
265
266        if n_samples == 0 {
267            return Err(SklearsError::InvalidInput(
268                "Input array cannot be empty".to_string(),
269            ));
270        }
271
272        let projection = self.projection.as_ref().expect("operation should succeed");
273        let feature_dim = projection.nrows();
274
275        // Extract protein features
276        let mut protein_features = Array2::zeros((n_samples, feature_dim));
277
278        for i in 0..n_samples {
279            for j in 0..n_features.min(feature_dim) {
280                // Simulated amino acid encoding with physicochemical properties
281                let aa_value = x[[i, j % n_features]].abs();
282                protein_features[[i, j]] = aa_value;
283
284                // Add physicochemical property contributions if enabled
285                if self.use_properties && j + 20 < feature_dim {
286                    if let Some(weights) = &self.property_weights {
287                        for (prop_idx, &weight) in weights.iter().enumerate() {
288                            if j + 20 + prop_idx < feature_dim {
289                                protein_features[[i, j + 20 + prop_idx]] = aa_value * weight;
290                            }
291                        }
292                    }
293                }
294            }
295        }
296
297        // Apply random projection
298        let features = protein_features.dot(projection);
299
300        Ok(features)
301    }
302}
303
304// ============================================================================
305// Phylogenetic Kernel
306// ============================================================================
307
308/// Kernel method for phylogenetic analysis using evolutionary distances
309///
310/// This kernel computes features based on phylogenetic tree structure and
311/// evolutionary distances between species.
312///
313/// # References
314/// - Vert (2002): "A tree kernel to analyze phylogenetic profiles"
315pub struct PhylogeneticKernel<State = Untrained> {
316    /// Number of random features
317    n_components: usize,
318    /// Tree depth to consider
319    tree_depth: usize,
320    /// Whether to weight by branch length
321    use_branch_lengths: bool,
322    /// Random projection matrix
323    projection: Option<Array2<Float>>,
324    /// Branch length weights
325    branch_weights: Option<Array1<Float>>,
326    /// State marker
327    _state: PhantomData<State>,
328}
329
330impl PhylogeneticKernel<Untrained> {
331    /// Create a new phylogenetic kernel
332    pub fn new(n_components: usize, tree_depth: usize) -> Self {
333        Self {
334            n_components,
335            tree_depth,
336            use_branch_lengths: true,
337            projection: None,
338            branch_weights: None,
339            _state: PhantomData,
340        }
341    }
342
343    /// Set whether to use branch lengths for weighting
344    pub fn use_branch_lengths(mut self, use_branch_lengths: bool) -> Self {
345        self.use_branch_lengths = use_branch_lengths;
346        self
347    }
348}
349
350impl Default for PhylogeneticKernel<Untrained> {
351    fn default() -> Self {
352        Self::new(100, 5)
353    }
354}
355
356impl Fit<Array2<Float>, ()> for PhylogeneticKernel<Untrained> {
357    type Fitted = PhylogeneticKernel<Trained>;
358
359    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
360        let n_samples = x.nrows();
361        if n_samples == 0 {
362            return Err(SklearsError::InvalidInput(
363                "Input array cannot be empty".to_string(),
364            ));
365        }
366
367        // Feature dimension based on tree structure
368        let feature_dim = 2usize.pow(self.tree_depth as u32);
369
370        let mut rng = thread_rng();
371        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
372            .expect("operation should succeed");
373
374        // Generate random projection
375        let mut projection = Array2::zeros((feature_dim, self.n_components));
376        for i in 0..feature_dim {
377            for j in 0..self.n_components {
378                projection[[i, j]] = normal.sample(&mut rng);
379            }
380        }
381
382        // Initialize branch length weights (exponentially decaying with depth)
383        let branch_weights = if self.use_branch_lengths {
384            let mut weights = Array1::zeros(self.tree_depth);
385            for i in 0..self.tree_depth {
386                weights[i] = (-(i as Float) * 0.5).exp();
387            }
388            Some(weights)
389        } else {
390            None
391        };
392
393        Ok(PhylogeneticKernel {
394            n_components: self.n_components,
395            tree_depth: self.tree_depth,
396            use_branch_lengths: self.use_branch_lengths,
397            projection: Some(projection),
398            branch_weights,
399            _state: PhantomData,
400        })
401    }
402}
403
404impl Transform<Array2<Float>, Array2<Float>> for PhylogeneticKernel<Trained> {
405    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
406        let n_samples = x.nrows();
407        let n_features = x.ncols();
408
409        if n_samples == 0 {
410            return Err(SklearsError::InvalidInput(
411                "Input array cannot be empty".to_string(),
412            ));
413        }
414
415        let projection = self.projection.as_ref().expect("operation should succeed");
416        let feature_dim = projection.nrows();
417
418        // Extract phylogenetic features
419        let mut tree_features = Array2::zeros((n_samples, feature_dim));
420
421        for i in 0..n_samples {
422            for j in 0..n_features.min(feature_dim) {
423                let base_value = x[[i, j % n_features]].abs();
424
425                // Apply branch length weighting if enabled
426                if self.use_branch_lengths {
427                    if let Some(weights) = &self.branch_weights {
428                        let depth_idx = j % self.tree_depth;
429                        tree_features[[i, j]] = base_value * weights[depth_idx];
430                    } else {
431                        tree_features[[i, j]] = base_value;
432                    }
433                } else {
434                    tree_features[[i, j]] = base_value;
435                }
436            }
437        }
438
439        // Apply random projection
440        let features = tree_features.dot(projection);
441
442        Ok(features)
443    }
444}
445
446// ============================================================================
447// Metabolic Network Kernel
448// ============================================================================
449
450/// Kernel method for metabolic network and pathway analysis
451///
452/// This kernel analyzes metabolic networks using graph-based features,
453/// pathway similarities, and network topology.
454///
455/// # References
456/// - Borgwardt et al. (2005): "Shortest-path kernels on graphs"
457pub struct MetabolicNetworkKernel<State = Untrained> {
458    /// Number of random features
459    n_components: usize,
460    /// Maximum path length to consider
461    max_path_length: usize,
462    /// Whether to include pathway enrichment
463    use_pathway_enrichment: bool,
464    /// Random projection matrix
465    projection: Option<Array2<Float>>,
466    /// Pathway weights
467    pathway_weights: Option<Array1<Float>>,
468    /// State marker
469    _state: PhantomData<State>,
470}
471
472impl MetabolicNetworkKernel<Untrained> {
473    /// Create a new metabolic network kernel
474    pub fn new(n_components: usize, max_path_length: usize) -> Self {
475        Self {
476            n_components,
477            max_path_length,
478            use_pathway_enrichment: true,
479            projection: None,
480            pathway_weights: None,
481            _state: PhantomData,
482        }
483    }
484
485    /// Set whether to use pathway enrichment features
486    pub fn use_pathway_enrichment(mut self, use_pathway_enrichment: bool) -> Self {
487        self.use_pathway_enrichment = use_pathway_enrichment;
488        self
489    }
490}
491
492impl Default for MetabolicNetworkKernel<Untrained> {
493    fn default() -> Self {
494        Self::new(100, 4)
495    }
496}
497
498impl Fit<Array2<Float>, ()> for MetabolicNetworkKernel<Untrained> {
499    type Fitted = MetabolicNetworkKernel<Trained>;
500
501    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
502        let n_samples = x.nrows();
503        if n_samples == 0 {
504            return Err(SklearsError::InvalidInput(
505                "Input array cannot be empty".to_string(),
506            ));
507        }
508
509        // Feature dimension based on network structure
510        let base_dim = 50; // Network topology features
511        let pathway_dim = if self.use_pathway_enrichment { 20 } else { 0 };
512        let feature_dim = base_dim + pathway_dim;
513
514        let mut rng = thread_rng();
515        let normal = Normal::new(0.0, 1.0 / (feature_dim as Float).sqrt())
516            .expect("operation should succeed");
517
518        // Generate random projection
519        let mut projection = Array2::zeros((feature_dim, self.n_components));
520        for i in 0..feature_dim {
521            for j in 0..self.n_components {
522                projection[[i, j]] = normal.sample(&mut rng);
523            }
524        }
525
526        // Initialize pathway enrichment weights
527        let pathway_weights = if self.use_pathway_enrichment {
528            let mut weights = Array1::zeros(pathway_dim);
529            for i in 0..pathway_dim {
530                // Different pathways have different importance
531                weights[i] = 1.0 / (1.0 + (i as Float) * 0.1);
532            }
533            Some(weights)
534        } else {
535            None
536        };
537
538        Ok(MetabolicNetworkKernel {
539            n_components: self.n_components,
540            max_path_length: self.max_path_length,
541            use_pathway_enrichment: self.use_pathway_enrichment,
542            projection: Some(projection),
543            pathway_weights,
544            _state: PhantomData,
545        })
546    }
547}
548
549impl Transform<Array2<Float>, Array2<Float>> for MetabolicNetworkKernel<Trained> {
550    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
551        let n_samples = x.nrows();
552        let n_features = x.ncols();
553
554        if n_samples == 0 {
555            return Err(SklearsError::InvalidInput(
556                "Input array cannot be empty".to_string(),
557            ));
558        }
559
560        let projection = self.projection.as_ref().expect("operation should succeed");
561        let feature_dim = projection.nrows();
562
563        // Extract network features
564        let mut network_features = Array2::zeros((n_samples, feature_dim));
565
566        for i in 0..n_samples {
567            // Network topology features
568            for j in 0..n_features.min(feature_dim) {
569                network_features[[i, j]] = x[[i, j % n_features]].abs();
570            }
571
572            // Add pathway enrichment features if enabled
573            if self.use_pathway_enrichment {
574                if let Some(weights) = &self.pathway_weights {
575                    let pathway_start = 50;
576                    for (pathway_idx, &weight) in weights.iter().enumerate() {
577                        if pathway_start + pathway_idx < feature_dim {
578                            // Pathway enrichment score weighted by importance
579                            let pathway_value = x[[i, pathway_idx % n_features]].abs() * weight;
580                            network_features[[i, pathway_start + pathway_idx]] = pathway_value;
581                        }
582                    }
583                }
584            }
585        }
586
587        // Apply random projection
588        let features = network_features.dot(projection);
589
590        Ok(features)
591    }
592}
593
594// ============================================================================
595// Multi-Omics Kernel
596// ============================================================================
597
/// Strategy used by [`MultiOmicsKernel`] to combine the per-omics random features.
///
/// Derives `PartialEq`/`Eq`/`Hash` so that configurations can be compared and used as keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OmicsIntegrationMethod {
    /// Nominally concatenation. The transform averages the per-omics features instead, so the
    /// output width stays `n_components`.
    Concatenation,
    /// Average of the per-omics features, weighted by the fitted per-type weights.
    WeightedAverage,
    /// Each omics type's features plus `0.1` times every other type's features, averaged over
    /// types.
    CrossCorrelation,
    /// Multi-view combination of per-omics ("view") features; currently computed as the same
    /// weighted average as `WeightedAverage`.
    MultiViewLearning,
}
610
611/// Kernel method for multi-omics data integration
612///
613/// This kernel integrates multiple omics data types (genomics, transcriptomics,
614/// proteomics, metabolomics) using various integration strategies.
615///
616/// # References
617/// - Nguyen et al. (2017): "A novel approach for data integration and disease subtyping"
618pub struct MultiOmicsKernel<State = Untrained> {
619    /// Number of random features
620    n_components: usize,
621    /// Number of different omics types
622    n_omics_types: usize,
623    /// Integration method
624    integration_method: OmicsIntegrationMethod,
625    /// Random projection matrices (one per omics type)
626    projections: Option<Vec<Array2<Float>>>,
627    /// Omics type weights
628    omics_weights: Option<Array1<Float>>,
629    /// State marker
630    _state: PhantomData<State>,
631}
632
633impl MultiOmicsKernel<Untrained> {
634    /// Create a new multi-omics kernel
635    pub fn new(n_components: usize, n_omics_types: usize) -> Self {
636        Self {
637            n_components,
638            n_omics_types,
639            integration_method: OmicsIntegrationMethod::WeightedAverage,
640            projections: None,
641            omics_weights: None,
642            _state: PhantomData,
643        }
644    }
645
646    /// Set the integration method
647    pub fn integration_method(mut self, method: OmicsIntegrationMethod) -> Self {
648        self.integration_method = method;
649        self
650    }
651}
652
653impl Default for MultiOmicsKernel<Untrained> {
654    fn default() -> Self {
655        Self::new(100, 3)
656    }
657}
658
659impl Fit<Array2<Float>, ()> for MultiOmicsKernel<Untrained> {
660    type Fitted = MultiOmicsKernel<Trained>;
661
662    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
663        let n_samples = x.nrows();
664        let n_features = x.ncols();
665
666        if n_samples == 0 {
667            return Err(SklearsError::InvalidInput(
668                "Input array cannot be empty".to_string(),
669            ));
670        }
671
672        // Assume features are divided equally among omics types
673        let features_per_omics = n_features / self.n_omics_types;
674
675        let mut rng = thread_rng();
676
677        // Generate separate projection for each omics type
678        let mut projections = Vec::new();
679        for _ in 0..self.n_omics_types {
680            let normal = Normal::new(0.0, 1.0 / (features_per_omics as Float).sqrt())
681                .expect("operation should succeed");
682            let mut projection = Array2::zeros((features_per_omics, self.n_components));
683
684            for i in 0..features_per_omics {
685                for j in 0..self.n_components {
686                    projection[[i, j]] = normal.sample(&mut rng);
687                }
688            }
689            projections.push(projection);
690        }
691
692        // Initialize omics type weights (different weights for genomics, transcriptomics, proteomics, etc.)
693        let mut omics_weights = Array1::zeros(self.n_omics_types);
694        for i in 0..self.n_omics_types {
695            // Decreasing importance: genomics > transcriptomics > proteomics > metabolomics
696            omics_weights[i] = 1.0 / (1.0 + (i as Float) * 0.2);
697        }
698
699        Ok(MultiOmicsKernel {
700            n_components: self.n_components,
701            n_omics_types: self.n_omics_types,
702            integration_method: self.integration_method,
703            projections: Some(projections),
704            omics_weights: Some(omics_weights),
705            _state: PhantomData,
706        })
707    }
708}
709
710impl Transform<Array2<Float>, Array2<Float>> for MultiOmicsKernel<Trained> {
711    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
712        let n_samples = x.nrows();
713        let n_features = x.ncols();
714
715        if n_samples == 0 {
716            return Err(SklearsError::InvalidInput(
717                "Input array cannot be empty".to_string(),
718            ));
719        }
720
721        let projections = self.projections.as_ref().expect("operation should succeed");
722        let omics_weights = self
723            .omics_weights
724            .as_ref()
725            .expect("operation should succeed");
726        let features_per_omics = n_features / self.n_omics_types;
727
728        let mut result = Array2::zeros((n_samples, self.n_components));
729
730        match self.integration_method {
731            OmicsIntegrationMethod::Concatenation => {
732                // Project each omics type separately and concatenate
733                // For simplicity, we average instead of concatenating to maintain dimension
734                for omics_idx in 0..self.n_omics_types {
735                    let start_idx = omics_idx * features_per_omics;
736                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
737
738                    if start_idx < n_features {
739                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
740                        for i in 0..n_samples {
741                            for j in 0..(end_idx - start_idx) {
742                                omics_data[[i, j]] = x[[i, start_idx + j]];
743                            }
744                        }
745                        let omics_features = omics_data.dot(&projections[omics_idx]);
746                        result += &omics_features;
747                    }
748                }
749                result /= self.n_omics_types as Float;
750            }
751            OmicsIntegrationMethod::WeightedAverage => {
752                // Weighted combination of omics-specific features
753                for omics_idx in 0..self.n_omics_types {
754                    let start_idx = omics_idx * features_per_omics;
755                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
756
757                    if start_idx < n_features {
758                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
759                        for i in 0..n_samples {
760                            for j in 0..(end_idx - start_idx) {
761                                omics_data[[i, j]] = x[[i, start_idx + j]];
762                            }
763                        }
764                        let omics_features = omics_data.dot(&projections[omics_idx]);
765                        let weight = omics_weights[omics_idx];
766                        result += &(omics_features * weight);
767                    }
768                }
769                // Normalize by sum of weights
770                let weight_sum: Float = omics_weights.sum();
771                result /= weight_sum;
772            }
773            OmicsIntegrationMethod::CrossCorrelation => {
774                // Include cross-correlation between omics types
775                for omics_idx in 0..self.n_omics_types {
776                    let start_idx = omics_idx * features_per_omics;
777                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
778
779                    if start_idx < n_features {
780                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
781                        for i in 0..n_samples {
782                            for j in 0..(end_idx - start_idx) {
783                                omics_data[[i, j]] = x[[i, start_idx + j]];
784                            }
785                        }
786                        let mut omics_features = omics_data.dot(&projections[omics_idx]);
787
788                        // Add cross-correlation with other omics types
789                        for other_idx in 0..self.n_omics_types {
790                            if other_idx != omics_idx {
791                                let other_start = other_idx * features_per_omics;
792                                let other_end =
793                                    ((other_idx + 1) * features_per_omics).min(n_features);
794
795                                if other_start < n_features {
796                                    let mut other_data =
797                                        Array2::zeros((n_samples, other_end - other_start));
798                                    for i in 0..n_samples {
799                                        for j in 0..(other_end - other_start) {
800                                            other_data[[i, j]] = x[[i, other_start + j]];
801                                        }
802                                    }
803                                    let other_features = other_data.dot(&projections[other_idx]);
804                                    // Element-wise multiplication for cross-correlation
805                                    omics_features += &(&other_features * 0.1);
806                                }
807                            }
808                        }
809
810                        result += &omics_features;
811                    }
812                }
813                result /= self.n_omics_types as Float;
814            }
815            OmicsIntegrationMethod::MultiViewLearning => {
816                // Multi-view learning with view-specific and shared features
817                let mut view_features = Vec::new();
818
819                for omics_idx in 0..self.n_omics_types {
820                    let start_idx = omics_idx * features_per_omics;
821                    let end_idx = ((omics_idx + 1) * features_per_omics).min(n_features);
822
823                    if start_idx < n_features {
824                        let mut omics_data = Array2::zeros((n_samples, end_idx - start_idx));
825                        for i in 0..n_samples {
826                            for j in 0..(end_idx - start_idx) {
827                                omics_data[[i, j]] = x[[i, start_idx + j]];
828                            }
829                        }
830                        let omics_features = omics_data.dot(&projections[omics_idx]);
831                        view_features.push(omics_features);
832                    }
833                }
834
835                // Combine views using weighted average
836                for (idx, features) in view_features.iter().enumerate() {
837                    let weight = omics_weights[idx];
838                    result += &(features * weight);
839                }
840                let weight_sum: Float = omics_weights.sum();
841                result /= weight_sum;
842            }
843        }
844
845        Ok(result)
846    }
847}
848
849// ============================================================================
850// Tests
851// ============================================================================
852
#[cfg(test)]
mod tests {
    //! Smoke tests for the bioinformatics kernels.
    //!
    //! Each test fits a kernel on a small dense matrix, transforms the same
    //! matrix, and checks the shape of the resulting feature matrix (plus,
    //! where relevant, that optional learned weights were populated).

    use super::*;
    use scirs2_core::ndarray::array;

    /// `GenomicKernel::new(3, 50)` maps each of the 3 input rows to 50 features.
    #[test]
    fn test_genomic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let kernel = GenomicKernel::new(3, 50);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[3, 50]);
    }

    /// Disabling normalization must not change the output shape.
    #[test]
    fn test_genomic_kernel_normalization() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = GenomicKernel::new(3, 30).normalize(false);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 30]);
    }

    /// `ProteinKernel::new(2, 40)` yields 40 features per input row.
    #[test]
    fn test_protein_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = ProteinKernel::new(2, 40);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 40]);
    }

    /// With `use_properties(true)`, fitting populates `property_weights`.
    #[test]
    fn test_protein_kernel_properties() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = ProteinKernel::new(2, 30).use_properties(true);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 30]);
        assert!(fitted.property_weights.is_some());
    }

    /// `PhylogeneticKernel::new(50, 4)` yields 50 features per input row
    /// (here the output width is the *first* constructor argument).
    #[test]
    fn test_phylogenetic_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = PhylogeneticKernel::new(50, 4);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 50]);
    }

    /// With `use_branch_lengths(true)`, fitting populates `branch_weights`.
    #[test]
    fn test_phylogenetic_kernel_branch_lengths() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let kernel = PhylogeneticKernel::new(40, 3).use_branch_lengths(true);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 40]);
        assert!(fitted.branch_weights.is_some());
    }

    /// `MetabolicNetworkKernel::new(60, 3)` yields 60 features per input row.
    #[test]
    fn test_metabolic_network_kernel_basic() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let kernel = MetabolicNetworkKernel::new(60, 3);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 60]);
    }

    /// With `use_pathway_enrichment(true)`, fitting populates `pathway_weights`.
    #[test]
    fn test_metabolic_network_kernel_pathways() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let kernel = MetabolicNetworkKernel::new(50, 3).use_pathway_enrichment(true);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 50]);
        assert!(fitted.pathway_weights.is_some());
    }

    /// `MultiOmicsKernel::new(40, 3)` with the default integration method
    /// yields 40 features per input row.
    #[test]
    fn test_multi_omics_kernel_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
        ];

        let kernel = MultiOmicsKernel::new(40, 3);
        let fitted = kernel.fit(&x, &()).expect("operation should succeed");
        let features = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(features.shape(), &[2, 40]);
    }

    /// Every `OmicsIntegrationMethod` variant must produce the same output shape.
    #[test]
    fn test_multi_omics_integration_methods() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        // A fixed-size array suffices: it is only iterated once, by value
        // (`IntoIterator for [T; N]`), so a heap-allocated `Vec` is unnecessary
        // (clippy::useless_vec).
        let methods = [
            OmicsIntegrationMethod::Concatenation,
            OmicsIntegrationMethod::WeightedAverage,
            OmicsIntegrationMethod::CrossCorrelation,
            OmicsIntegrationMethod::MultiViewLearning,
        ];

        for method in methods {
            let kernel = MultiOmicsKernel::new(30, 2).integration_method(method);
            let fitted = kernel.fit(&x, &()).expect("operation should succeed");
            let features = fitted.transform(&x).expect("operation should succeed");
            assert_eq!(features.shape(), &[2, 30]);
        }
    }

    /// Fitting on a matrix with zero rows is rejected with an error.
    #[test]
    fn test_empty_input_error() {
        let x_empty: Array2<Float> = Array2::zeros((0, 3));

        let kernel = GenomicKernel::new(3, 50);
        assert!(kernel.fit(&x_empty, &()).is_err());

        let kernel2 = ProteinKernel::new(2, 40);
        assert!(kernel2.fit(&x_empty, &()).is_err());
    }
}