1use scirs2_core::ndarray::Array2;
30use sklears_core::{
31 error::Result,
32 prelude::{Fit, Transform},
33};
34use std::collections::HashMap;
35
/// Configuration for an n-gram feature extractor over string sequences.
///
/// Built with [`NGramKernel::new`] and the chained setters; fitting it yields
/// a [`FittedNGramKernel`] holding the learned vocabulary.
#[derive(Debug, Clone)]
pub struct NGramKernel {
    /// Number of consecutive tokens that make up one n-gram.
    n: usize,
    /// Whether each feature row is rescaled to unit L2 norm.
    normalize: bool,
    /// Whether features record presence (1.0) rather than occurrence counts.
    binary: bool,
    /// How a sequence is split into tokens.
    mode: NGramMode,
}

/// Tokenisation strategy used when extracting n-grams.
#[derive(Debug, Clone)]
pub enum NGramMode {
    /// Tokens are the `char`s of the sequence.
    Character,
    /// Tokens are whitespace-separated words; n-grams are joined with one space.
    Word,
    /// Tokens are the pieces between occurrences of `delimiter`; n-grams are
    /// joined with the same delimiter.
    Custom { delimiter: String },
}

/// An [`NGramKernel`] together with the vocabulary learned from training data.
#[derive(Debug, Clone)]
pub struct FittedNGramKernel {
    /// Maps each known n-gram to its feature column index.
    vocabulary: HashMap<String, usize>,
    /// Number of consecutive tokens that make up one n-gram.
    n: usize,
    /// Whether each feature row is rescaled to unit L2 norm.
    normalize: bool,
    /// Whether features record presence (1.0) rather than occurrence counts.
    binary: bool,
    /// How a sequence is split into tokens.
    mode: NGramMode,
}

impl NGramKernel {
    /// Creates a character-level kernel for n-grams of length `n`, with
    /// normalization enabled and binary features disabled.
    pub fn new(n: usize) -> Self {
        Self {
            n,
            normalize: true,
            binary: false,
            mode: NGramMode::Character,
        }
    }

    /// Sets whether feature rows are L2-normalized.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Sets whether features are binary indicators instead of counts.
    pub fn binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    /// Sets the tokenisation mode.
    pub fn mode(mut self, mode: NGramMode) -> Self {
        self.mode = mode;
        self
    }

    /// Returns every n-gram of `sequence`, in order of occurrence.
    ///
    /// A sequence with fewer than `n` tokens yields no n-grams. `n == 0` also
    /// yields none.
    fn extract_ngrams(&self, sequence: &str) -> Vec<String> {
        // `slice::windows` panics on a window size of zero, so a zero-length
        // n-gram request is answered with "nothing" rather than an abort.
        if self.n == 0 {
            return Vec::new();
        }

        match &self.mode {
            NGramMode::Character => {
                let chars: Vec<char> = sequence.chars().collect();
                chars
                    .windows(self.n)
                    .map(|window| window.iter().collect())
                    .collect()
            }
            NGramMode::Word => {
                let words: Vec<&str> = sequence.split_whitespace().collect();
                words
                    .windows(self.n)
                    .map(|window| window.join(" "))
                    .collect()
            }
            NGramMode::Custom { delimiter } => {
                let tokens: Vec<&str> = sequence.split(delimiter.as_str()).collect();
                tokens
                    .windows(self.n)
                    .map(|window| window.join(delimiter))
                    .collect()
            }
        }
    }
}
131
132impl Fit<Vec<String>, ()> for NGramKernel {
133 type Fitted = FittedNGramKernel;
134
135 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
136 let mut vocabulary = HashMap::new();
137 let mut vocab_index = 0;
138
139 for sequence in sequences {
141 let ngrams = self.extract_ngrams(sequence);
142 for ngram in ngrams {
143 if !vocabulary.contains_key(&ngram) {
144 vocabulary.insert(ngram, vocab_index);
145 vocab_index += 1;
146 }
147 }
148 }
149
150 Ok(FittedNGramKernel {
151 vocabulary,
152 n: self.n,
153 normalize: self.normalize,
154 binary: self.binary,
155 mode: self.mode.clone(),
156 })
157 }
158}
159
160impl Transform<Vec<String>, Array2<f64>> for FittedNGramKernel {
161 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
162 let n_sequences = sequences.len();
163 let vocab_size = self.vocabulary.len();
164 let mut features = Array2::zeros((n_sequences, vocab_size));
165
166 for (i, sequence) in sequences.iter().enumerate() {
167 let ngrams = match &self.mode {
168 NGramMode::Character => {
169 let chars: Vec<char> = sequence.chars().collect();
170 chars
171 .windows(self.n)
172 .map(|window| window.iter().collect::<String>())
173 .collect::<Vec<String>>()
174 }
175 NGramMode::Word => {
176 let words: Vec<&str> = sequence.split_whitespace().collect();
177 words
178 .windows(self.n)
179 .map(|window| window.join(" "))
180 .collect::<Vec<String>>()
181 }
182 NGramMode::Custom { delimiter } => {
183 let tokens: Vec<&str> = sequence.split(delimiter).collect();
184 tokens
185 .windows(self.n)
186 .map(|window| window.join(delimiter))
187 .collect::<Vec<String>>()
188 }
189 };
190
191 let mut ngram_counts = HashMap::new();
193 for ngram in ngrams {
194 if let Some(&vocab_idx) = self.vocabulary.get(&ngram) {
195 *ngram_counts.entry(vocab_idx).or_insert(0) += 1;
196 }
197 }
198
199 for (vocab_idx, count) in ngram_counts {
201 features[(i, vocab_idx)] = if self.binary { 1.0 } else { count as f64 };
202 }
203
204 if self.normalize {
206 let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
207 if norm > 0.0 {
208 for j in 0..vocab_size {
209 features[(i, j)] /= norm;
210 }
211 }
212 }
213 }
214
215 Ok(features)
216 }
217}
218
/// Configuration for a k-mer spectrum feature map over character sequences.
#[derive(Debug, Clone)]
pub struct SpectrumKernel {
    /// Length of the contiguous character windows (k-mers).
    k: usize,
    /// Whether transformed rows are rescaled to unit L2 norm.
    normalize: bool,
}

impl SpectrumKernel {
    /// Creates a spectrum kernel for k-mers of length `k`; normalization
    /// starts enabled.
    pub fn new(k: usize) -> Self {
        let normalize = true;
        Self { k, normalize }
    }

    /// Returns the kernel with its normalization flag replaced by `normalize`.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
241
242#[derive(Debug, Clone)]
244pub struct FittedSpectrumKernel {
246 vocabulary: HashMap<String, usize>,
248 k: usize,
250 normalize: bool,
252}
253
254impl Fit<Vec<String>, ()> for SpectrumKernel {
255 type Fitted = FittedSpectrumKernel;
256
257 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
258 let mut vocabulary = HashMap::new();
259 let mut vocab_index = 0;
260
261 for sequence in sequences {
263 let chars: Vec<char> = sequence.chars().collect();
264 for window in chars.windows(self.k) {
265 let kmer: String = window.iter().collect();
266 if !vocabulary.contains_key(&kmer) {
267 vocabulary.insert(kmer, vocab_index);
268 vocab_index += 1;
269 }
270 }
271 }
272
273 Ok(FittedSpectrumKernel {
274 vocabulary,
275 k: self.k,
276 normalize: self.normalize,
277 })
278 }
279}
280
281impl Transform<Vec<String>, Array2<f64>> for FittedSpectrumKernel {
282 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
283 let n_sequences = sequences.len();
284 let vocab_size = self.vocabulary.len();
285 let mut features = Array2::zeros((n_sequences, vocab_size));
286
287 for (i, sequence) in sequences.iter().enumerate() {
288 let chars: Vec<char> = sequence.chars().collect();
289 let mut kmer_counts = HashMap::new();
290
291 for window in chars.windows(self.k) {
293 let kmer: String = window.iter().collect();
294 if let Some(&vocab_idx) = self.vocabulary.get(&kmer) {
295 *kmer_counts.entry(vocab_idx).or_insert(0) += 1;
296 }
297 }
298
299 for (vocab_idx, count) in kmer_counts {
301 features[(i, vocab_idx)] = count as f64;
302 }
303
304 if self.normalize {
306 let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
307 if norm > 0.0 {
308 for j in 0..vocab_size {
309 features[(i, j)] /= norm;
310 }
311 }
312 }
313 }
314
315 Ok(features)
316 }
317}
318
/// Configuration for a gap-penalised subsequence similarity between strings.
#[derive(Debug, Clone)]
pub struct SubsequenceKernel {
    /// Largest subsequence length that contributes to the score.
    max_length: usize,
    /// Decay factor multiplied in at every step of the recurrence.
    gap_penalty: f64,
    /// Normalization flag; not consulted by `subsequence_kernel_value`.
    normalize: bool,
}

impl SubsequenceKernel {
    /// Creates a kernel for subsequences up to `max_length` with the given
    /// `gap_penalty`; normalization starts enabled.
    pub fn new(max_length: usize, gap_penalty: f64) -> Self {
        let normalize = true;
        Self {
            max_length,
            gap_penalty,
            normalize,
        }
    }

    /// Returns the kernel with its normalization flag replaced by `normalize`.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Scores the similarity of `s1` and `s2` by dynamic programming.
    ///
    /// Returns 0.0 when either string is empty. Otherwise the result is the
    /// sum, over lengths `1..=max_length`, of the final table cell of each
    /// length's layer.
    fn subsequence_kernel_value(&self, s1: &str, s2: &str) -> f64 {
        let left: Vec<char> = s1.chars().collect();
        let right: Vec<char> = s2.chars().collect();
        if left.is_empty() || right.is_empty() {
            return 0.0;
        }

        let (rows, cols) = (left.len(), right.len());
        let lambda = self.gap_penalty;

        // table[k][i][j] covers the first i chars of `left` and the first j
        // chars of `right` at length k. Layer 0 is all ones; the border row
        // and column of every other layer stay zero.
        let mut table = vec![vec![vec![0.0; cols + 1]; rows + 1]; self.max_length + 1];
        for row in table[0].iter_mut() {
            row.fill(1.0);
        }

        for k in 1..=self.max_length {
            // Split so layer k can be written while layer k - 1 is read.
            let (lower, upper) = table.split_at_mut(k);
            let previous = &lower[k - 1];
            let current = &mut upper[0];

            for (i, &a) in left.iter().enumerate() {
                for (j, &b) in right.iter().enumerate() {
                    let matched = a == b;
                    let diagonal = previous[i][j];

                    // The four terms are applied in this exact order so the
                    // floating-point result is reproducible.
                    let mut cell = lambda * current[i][j + 1];
                    if matched {
                        cell += lambda * diagonal;
                    }
                    cell += lambda * current[i + 1][j];
                    if matched {
                        cell -= lambda * lambda * diagonal;
                    }

                    current[i + 1][j + 1] = cell;
                }
            }
        }

        table
            .iter()
            .skip(1)
            .fold(0.0, |sum, layer| sum + layer[rows][cols])
    }
}
400
401#[derive(Debug, Clone)]
403pub struct FittedSubsequenceKernel {
405 training_sequences: Vec<String>,
407 max_length: usize,
409 gap_penalty: f64,
411 normalize: bool,
413}
414
415impl Fit<Vec<String>, ()> for SubsequenceKernel {
416 type Fitted = FittedSubsequenceKernel;
417
418 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
419 Ok(FittedSubsequenceKernel {
420 training_sequences: sequences.clone(),
421 max_length: self.max_length,
422 gap_penalty: self.gap_penalty,
423 normalize: self.normalize,
424 })
425 }
426}
427
428impl Transform<Vec<String>, Array2<f64>> for FittedSubsequenceKernel {
429 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
430 let n_test = sequences.len();
431 let n_train = self.training_sequences.len();
432 let mut kernel_matrix = Array2::zeros((n_test, n_train));
433
434 let kernel = SubsequenceKernel {
436 max_length: self.max_length,
437 gap_penalty: self.gap_penalty,
438 normalize: false, };
440
441 for i in 0..n_test {
442 for j in 0..n_train {
443 kernel_matrix[(i, j)] =
444 kernel.subsequence_kernel_value(&sequences[i], &self.training_sequences[j]);
445 }
446
447 if self.normalize {
449 let norm = kernel_matrix.row(i).mapv(|x| x * x).sum().sqrt();
450 if norm > 0.0 {
451 for j in 0..n_train {
452 kernel_matrix[(i, j)] /= norm;
453 }
454 }
455 }
456 }
457
458 Ok(kernel_matrix)
459 }
460}
461
/// Configuration for a kernel based on Levenshtein edit distance.
#[derive(Debug, Clone)]
pub struct EditDistanceKernel {
    /// Distances above this threshold give a kernel value of zero.
    max_distance: usize,
    /// Scale of the exponential decay `exp(-distance / sigma)`.
    sigma: f64,
}

impl EditDistanceKernel {
    /// Creates a kernel with the given distance cut-off and decay scale.
    pub fn new(max_distance: usize, sigma: f64) -> Self {
        Self {
            max_distance,
            sigma,
        }
    }

    /// Returns the Levenshtein distance between `s1` and `s2`, counted in
    /// `char`s, with unit cost for insertion, deletion and substitution.
    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
        let source: Vec<char> = s1.chars().collect();
        let target: Vec<char> = s2.chars().collect();

        // Only the previous and the current row of the classic table are
        // ever read, so two rows are kept instead of the full matrix.
        let mut previous: Vec<usize> = (0..=target.len()).collect();
        let mut current = vec![0; target.len() + 1];

        for (i, &a) in source.iter().enumerate() {
            current[0] = i + 1;
            for (j, &b) in target.iter().enumerate() {
                let deletion = previous[j + 1] + 1;
                let insertion = current[j] + 1;
                let substitution = previous[j] + usize::from(a != b);
                current[j + 1] = deletion.min(insertion).min(substitution);
            }
            std::mem::swap(&mut previous, &mut current);
        }

        // After the final swap the last completed row lives in `previous`;
        // for an empty `s1` it is still the initial row 0..=len(s2).
        previous[target.len()]
    }

    /// Returns `exp(-distance / sigma)`, or 0.0 when the edit distance
    /// exceeds `max_distance`.
    ///
    /// Identical strings always score exactly 1.0.
    fn kernel_value(&self, s1: &str, s2: &str) -> f64 {
        let distance = self.edit_distance(s1, s2);
        if distance > self.max_distance {
            0.0
        } else if distance == 0 {
            // exp(0) is 1 for every sigma; answering directly avoids the
            // 0/0 = NaN that `sigma == 0.0` would otherwise produce.
            1.0
        } else {
            (-(distance as f64) / self.sigma).exp()
        }
    }
}
521
522#[derive(Debug, Clone)]
524pub struct FittedEditDistanceKernel {
526 training_sequences: Vec<String>,
528 max_distance: usize,
530 sigma: f64,
532}
533
534impl Fit<Vec<String>, ()> for EditDistanceKernel {
535 type Fitted = FittedEditDistanceKernel;
536
537 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
538 Ok(FittedEditDistanceKernel {
539 training_sequences: sequences.clone(),
540 max_distance: self.max_distance,
541 sigma: self.sigma,
542 })
543 }
544}
545
546impl Transform<Vec<String>, Array2<f64>> for FittedEditDistanceKernel {
547 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
548 let n_test = sequences.len();
549 let n_train = self.training_sequences.len();
550 let mut kernel_matrix = Array2::zeros((n_test, n_train));
551
552 let kernel = EditDistanceKernel {
553 max_distance: self.max_distance,
554 sigma: self.sigma,
555 };
556
557 for i in 0..n_test {
558 for j in 0..n_train {
559 kernel_matrix[(i, j)] =
560 kernel.kernel_value(&sequences[i], &self.training_sequences[j]);
561 }
562 }
563
564 Ok(kernel_matrix)
565 }
566}
567
/// Configuration for a k-mer kernel that tolerates character substitutions.
#[derive(Debug, Clone)]
pub struct MismatchKernel {
    /// Length of the k-mers.
    k: usize,
    /// Maximum number of substituted positions considered per k-mer.
    m: usize,
    /// Characters that may be substituted into a k-mer.
    alphabet: Vec<char>,
}

impl MismatchKernel {
    /// Creates a kernel for k-mers of length `k` with up to `m` mismatches,
    /// using the alphabet `A`, `C`, `G`, `T`.
    pub fn new(k: usize, m: usize) -> Self {
        Self {
            k,
            m,
            alphabet: vec!['A', 'C', 'G', 'T'],
        }
    }

    /// Returns the kernel with its substitution alphabet replaced.
    pub fn alphabet(self, alphabet: Vec<char>) -> Self {
        Self { alphabet, ..self }
    }

    /// Lists every string obtained from `kmer` by substituting exactly
    /// `mismatches` positions with a different alphabet character.
    ///
    /// With zero mismatches the result is just `kmer` itself.
    fn generate_neighborhood(&self, kmer: &str, mismatches: usize) -> Vec<String> {
        match mismatches {
            0 => vec![kmer.to_owned()],
            _ => {
                let original: Vec<char> = kmer.chars().collect();
                let mut scratch: Vec<char> = Vec::new();
                let mut found = Vec::new();
                self.generate_mismatches(&original, 0, mismatches, &mut scratch, &mut found);
                found
            }
        }
    }

    /// Recursive worker for [`Self::generate_neighborhood`].
    ///
    /// `current` holds the characters chosen for positions before `pos`.
    /// A completed candidate is appended to `result` only when the whole
    /// mismatch budget has been used.
    fn generate_mismatches(
        &self,
        original: &[char],
        pos: usize,
        mismatches_left: usize,
        current: &mut Vec<char>,
        result: &mut Vec<String>,
    ) {
        if pos == original.len() {
            if mismatches_left == 0 {
                result.push(current.iter().collect());
            }
            return;
        }

        let expected = original[pos];

        // First branch: keep the original character at this position.
        current.push(expected);
        self.generate_mismatches(original, pos + 1, mismatches_left, current, result);
        current.pop();

        if mismatches_left == 0 {
            return;
        }

        // Remaining branches: spend one mismatch on each other alphabet
        // character, in alphabet order.
        let substitutes = self.alphabet.iter().copied().filter(|&c| c != expected);
        for substitute in substitutes {
            current.push(substitute);
            self.generate_mismatches(original, pos + 1, mismatches_left - 1, current, result);
            current.pop();
        }
    }
}
648
649#[derive(Debug, Clone)]
651pub struct FittedMismatchKernel {
653 vocabulary: HashMap<String, usize>,
655 k: usize,
657 m: usize,
659 alphabet: Vec<char>,
661}
662
663impl Fit<Vec<String>, ()> for MismatchKernel {
664 type Fitted = FittedMismatchKernel;
665
666 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
667 let mut vocabulary = HashMap::new();
668 let mut vocab_index = 0;
669
670 for sequence in sequences {
672 let chars: Vec<char> = sequence.chars().collect();
673 for window in chars.windows(self.k) {
674 let kmer: String = window.iter().collect();
675
676 for mismatch_count in 0..=self.m {
678 let neighborhood = self.generate_neighborhood(&kmer, mismatch_count);
679 for neighbor in neighborhood {
680 if !vocabulary.contains_key(&neighbor) {
681 vocabulary.insert(neighbor, vocab_index);
682 vocab_index += 1;
683 }
684 }
685 }
686 }
687 }
688
689 Ok(FittedMismatchKernel {
690 vocabulary,
691 k: self.k,
692 m: self.m,
693 alphabet: self.alphabet.clone(),
694 })
695 }
696}
697
698impl Transform<Vec<String>, Array2<f64>> for FittedMismatchKernel {
699 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
700 let n_sequences = sequences.len();
701 let vocab_size = self.vocabulary.len();
702 let mut features = Array2::zeros((n_sequences, vocab_size));
703
704 let kernel = MismatchKernel {
705 k: self.k,
706 m: self.m,
707 alphabet: self.alphabet.clone(),
708 };
709
710 for (i, sequence) in sequences.iter().enumerate() {
711 let chars: Vec<char> = sequence.chars().collect();
712 let mut feature_counts = HashMap::new();
713
714 for window in chars.windows(self.k) {
716 let kmer: String = window.iter().collect();
717
718 for mismatch_count in 0..=self.m {
719 let neighborhood = kernel.generate_neighborhood(&kmer, mismatch_count);
720 for neighbor in neighborhood {
721 if let Some(&vocab_idx) = self.vocabulary.get(&neighbor) {
722 *feature_counts.entry(vocab_idx).or_insert(0) += 1;
723 }
724 }
725 }
726 }
727
728 for (vocab_idx, count) in feature_counts {
730 features[(i, vocab_idx)] = count as f64;
731 }
732 }
733
734 Ok(features)
735 }
736}
737
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Converts string literals into the owned `Vec<String>` the kernels take.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Character bigrams produce one finite, non-negative row per sequence.
    #[test]
    fn test_ngram_kernel_character() {
        let sequences = owned(&["hello", "world", "help"]);
        let fitted = NGramKernel::new(2)
            .mode(NGramMode::Character)
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Word bigrams produce one finite, non-negative row per sequence.
    #[test]
    fn test_ngram_kernel_word() {
        let sequences = owned(&["hello world", "world peace", "hello there"]);
        let fitted = NGramKernel::new(2)
            .mode(NGramMode::Word)
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Identical sequences (rows 0 and 2) must get identical spectrum features.
    #[test]
    fn test_spectrum_kernel() {
        let sequences = owned(&["ATCGATCG", "GCTAGCTA", "ATCGATCG"]);
        let fitted = SpectrumKernel::new(3).fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);

        for column in 0..features.ncols() {
            assert_abs_diff_eq!(
                features[(0, column)],
                features[(2, column)],
                epsilon = 1e-10
            );
        }
    }

    /// The subsequence kernel matrix is square over the training set, and
    /// identical sequences (rows 0 and 2) score the same against sequence 0.
    #[test]
    fn test_subsequence_kernel() {
        let sequences = owned(&["ABC", "ACB", "ABC"]);
        let fitted = SubsequenceKernel::new(3, 0.5)
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert_eq!(features.ncols(), 3);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));

        assert!(features[(0, 0)] > 0.0);
        assert_abs_diff_eq!(features[(0, 0)], features[(2, 0)], epsilon = 1e-10);
    }

    /// Edit-distance kernel values lie in [0, 1] with ones on the diagonal.
    #[test]
    fn test_edit_distance_kernel() {
        let sequences = owned(&["cat", "bat", "rat", "dog"]);
        let fitted = EditDistanceKernel::new(5, 1.0)
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 4);
        assert!(features
            .iter()
            .all(|&x| x >= 0.0 && x <= 1.0 && x.is_finite()));

        for diagonal in 0..4 {
            assert_abs_diff_eq!(features[(diagonal, diagonal)], 1.0, epsilon = 1e-10);
        }
    }

    /// Mismatch features are finite and non-negative for DNA-like input.
    #[test]
    fn test_mismatch_kernel() {
        let sequences = owned(&["ATCG", "ATCC", "GCTA"]);
        let fitted = MismatchKernel::new(3, 1)
            .alphabet(vec!['A', 'C', 'G', 'T'])
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Spot-checks the raw Levenshtein distance on well-known pairs.
    #[test]
    fn test_edit_distance_computation() {
        let kernel = EditDistanceKernel::new(10, 1.0);
        let cases = [
            ("", "", 0),
            ("cat", "cat", 0),
            ("cat", "bat", 1),
            ("cat", "dog", 3),
            ("kitten", "sitting", 3),
        ];

        for (left, right, expected) in cases {
            assert_eq!(kernel.edit_distance(left, right), expected);
        }
    }

    /// With binary features and no normalization every cell is 0 or 1.
    #[test]
    fn test_ngram_binary_mode() {
        let sequences = owned(&["aaa", "aba"]);
        let fitted = NGramKernel::new(2)
            .binary(true)
            .normalize(false)
            .fit(&sequences, &())
            .unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert!(features.iter().all(|&x| x == 0.0 || x == 1.0));
    }
}