1use scirs2_core::ndarray::Array2;
30use sklears_core::{
31 error::Result,
32 prelude::{Fit, Transform},
33};
34use std::collections::HashMap;
35
/// Configuration for an n-gram feature map over text sequences.
///
/// Built with [`NGramKernel::new`] and the builder-style setters, then fitted
/// on a corpus to learn the n-gram vocabulary.
#[derive(Debug, Clone)]
pub struct NGramKernel {
    /// Number of consecutive tokens that make up one n-gram.
    n: usize,
    /// Whether transformed feature rows are rescaled to unit L2 norm.
    normalize: bool,
    /// Whether features record presence (0/1) rather than occurrence counts.
    binary: bool,
    /// How a sequence is split into tokens before n-grams are formed.
    mode: NGramMode,
}

/// Tokenization strategy used when extracting n-grams.
#[derive(Debug, Clone)]
pub enum NGramMode {
    /// Every `char` of the sequence is one token.
    Character,
    /// Tokens are the whitespace-separated words of the sequence.
    Word,
    /// Tokens are the pieces obtained by splitting the sequence on `delimiter`.
    Custom { delimiter: String },
}

/// An [`NGramKernel`] after fitting: the learned vocabulary plus the settings
/// it was fitted with.
#[derive(Debug, Clone)]
pub struct FittedNGramKernel {
    /// Maps each n-gram seen during fitting to its feature column index.
    vocabulary: HashMap<String, usize>,
    /// Number of consecutive tokens that make up one n-gram.
    n: usize,
    /// Whether transformed feature rows are rescaled to unit L2 norm.
    normalize: bool,
    /// Whether features record presence (0/1) rather than occurrence counts.
    binary: bool,
    /// How a sequence is split into tokens before n-grams are formed.
    mode: NGramMode,
}

impl NGramKernel {
    /// Creates a kernel over `n`-grams.
    ///
    /// Defaults: normalization enabled, count (non-binary) features,
    /// character-level tokens.
    pub fn new(n: usize) -> Self {
        Self {
            n,
            normalize: true,
            binary: false,
            mode: NGramMode::Character,
        }
    }

    /// Sets whether feature rows are L2-normalized.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Sets whether features are binary (presence) instead of counts.
    pub fn binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    /// Sets the tokenization mode.
    pub fn mode(mut self, mode: NGramMode) -> Self {
        self.mode = mode;
        self
    }

    /// Returns every n-gram of `sequence`, in order of appearance and with
    /// duplicates kept.
    ///
    /// Tokens of an n-gram are re-joined with nothing (character mode), a
    /// single space (word mode) or the custom delimiter. The result is empty
    /// when the sequence has fewer than `n` tokens or when `n` is zero.
    fn extract_ngrams(&self, sequence: &str) -> Vec<String> {
        // `slice::windows` panics on a window size of zero, so a zero-length
        // n-gram is treated as "no n-grams" instead of aborting the caller.
        if self.n == 0 {
            return Vec::new();
        }

        match &self.mode {
            NGramMode::Character => {
                let chars: Vec<char> = sequence.chars().collect();
                chars
                    .windows(self.n)
                    .map(|window| window.iter().collect::<String>())
                    .collect()
            }
            NGramMode::Word => {
                let words: Vec<&str> = sequence.split_whitespace().collect();
                words
                    .windows(self.n)
                    .map(|window| window.join(" "))
                    .collect()
            }
            NGramMode::Custom { delimiter } => {
                let delimiter = delimiter.as_str();
                let tokens: Vec<&str> = sequence.split(delimiter).collect();
                tokens
                    .windows(self.n)
                    .map(|window| window.join(delimiter))
                    .collect()
            }
        }
    }
}
131
132impl Fit<Vec<String>, ()> for NGramKernel {
133 type Fitted = FittedNGramKernel;
134
135 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
136 let mut vocabulary = HashMap::new();
137 let mut vocab_index = 0;
138
139 for sequence in sequences {
141 let ngrams = self.extract_ngrams(sequence);
142 for ngram in ngrams {
143 if let std::collections::hash_map::Entry::Vacant(e) = vocabulary.entry(ngram) {
144 e.insert(vocab_index);
145 vocab_index += 1;
146 }
147 }
148 }
149
150 Ok(FittedNGramKernel {
151 vocabulary,
152 n: self.n,
153 normalize: self.normalize,
154 binary: self.binary,
155 mode: self.mode.clone(),
156 })
157 }
158}
159
160impl Transform<Vec<String>, Array2<f64>> for FittedNGramKernel {
161 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
162 let n_sequences = sequences.len();
163 let vocab_size = self.vocabulary.len();
164 let mut features = Array2::zeros((n_sequences, vocab_size));
165
166 for (i, sequence) in sequences.iter().enumerate() {
167 let ngrams = match &self.mode {
168 NGramMode::Character => {
169 let chars: Vec<char> = sequence.chars().collect();
170 chars
171 .windows(self.n)
172 .map(|window| window.iter().collect::<String>())
173 .collect::<Vec<String>>()
174 }
175 NGramMode::Word => {
176 let words: Vec<&str> = sequence.split_whitespace().collect();
177 words
178 .windows(self.n)
179 .map(|window| window.join(" "))
180 .collect::<Vec<String>>()
181 }
182 NGramMode::Custom { delimiter } => {
183 let tokens: Vec<&str> = sequence.split(delimiter).collect();
184 tokens
185 .windows(self.n)
186 .map(|window| window.join(delimiter))
187 .collect::<Vec<String>>()
188 }
189 };
190
191 let mut ngram_counts = HashMap::new();
193 for ngram in ngrams {
194 if let Some(&vocab_idx) = self.vocabulary.get(&ngram) {
195 *ngram_counts.entry(vocab_idx).or_insert(0) += 1;
196 }
197 }
198
199 for (vocab_idx, count) in ngram_counts {
201 features[(i, vocab_idx)] = if self.binary { 1.0 } else { count as f64 };
202 }
203
204 if self.normalize {
206 let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
207 if norm > 0.0 {
208 for j in 0..vocab_size {
209 features[(i, j)] /= norm;
210 }
211 }
212 }
213 }
214
215 Ok(features)
216 }
217}
218
/// Configuration for a spectrum (k-mer) feature map over character sequences.
#[derive(Debug, Clone)]
pub struct SpectrumKernel {
    /// Length, in `char`s, of every k-mer.
    k: usize,
    /// Whether transformed feature rows are rescaled to unit L2 norm.
    normalize: bool,
}

impl SpectrumKernel {
    /// Creates a spectrum kernel over `k`-mers with normalization enabled.
    pub fn new(k: usize) -> Self {
        let normalize = true;
        Self { k, normalize }
    }

    /// Returns the kernel with the normalization flag replaced.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }
}
241
/// A [`SpectrumKernel`] after fitting: the learned k-mer vocabulary plus the
/// settings it was fitted with.
#[derive(Debug, Clone)]
pub struct FittedSpectrumKernel {
    /// Maps each k-mer seen during fitting to its feature column index.
    vocabulary: HashMap<String, usize>,
    /// Length, in `char`s, of every k-mer.
    k: usize,
    /// Whether `transform` rescales each feature row to unit L2 norm.
    normalize: bool,
}
253
254impl Fit<Vec<String>, ()> for SpectrumKernel {
255 type Fitted = FittedSpectrumKernel;
256
257 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
258 let mut vocabulary = HashMap::new();
259 let mut vocab_index = 0;
260
261 for sequence in sequences {
263 let chars: Vec<char> = sequence.chars().collect();
264 for window in chars.windows(self.k) {
265 let kmer: String = window.iter().collect();
266 if let std::collections::hash_map::Entry::Vacant(e) = vocabulary.entry(kmer) {
267 e.insert(vocab_index);
268 vocab_index += 1;
269 }
270 }
271 }
272
273 Ok(FittedSpectrumKernel {
274 vocabulary,
275 k: self.k,
276 normalize: self.normalize,
277 })
278 }
279}
280
281impl Transform<Vec<String>, Array2<f64>> for FittedSpectrumKernel {
282 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
283 let n_sequences = sequences.len();
284 let vocab_size = self.vocabulary.len();
285 let mut features = Array2::zeros((n_sequences, vocab_size));
286
287 for (i, sequence) in sequences.iter().enumerate() {
288 let chars: Vec<char> = sequence.chars().collect();
289 let mut kmer_counts = HashMap::new();
290
291 for window in chars.windows(self.k) {
293 let kmer: String = window.iter().collect();
294 if let Some(&vocab_idx) = self.vocabulary.get(&kmer) {
295 *kmer_counts.entry(vocab_idx).or_insert(0) += 1;
296 }
297 }
298
299 for (vocab_idx, count) in kmer_counts {
301 features[(i, vocab_idx)] = count as f64;
302 }
303
304 if self.normalize {
306 let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
307 if norm > 0.0 {
308 for j in 0..vocab_size {
309 features[(i, j)] /= norm;
310 }
311 }
312 }
313 }
314
315 Ok(features)
316 }
317}
318
/// Configuration for a gap-weighted subsequence kernel between strings.
#[derive(Debug, Clone)]
pub struct SubsequenceKernel {
    /// Largest subsequence length that contributes to the kernel value.
    max_length: usize,
    /// Multiplicative decay factor applied by the recurrence.
    gap_penalty: f64,
    /// Whether the fitted transformer rescales each kernel row to unit L2 norm.
    normalize: bool,
}

impl SubsequenceKernel {
    /// Creates a kernel summing contributions for lengths `1..=max_length`,
    /// with normalization enabled.
    pub fn new(max_length: usize, gap_penalty: f64) -> Self {
        let normalize = true;
        Self {
            max_length,
            gap_penalty,
            normalize,
        }
    }

    /// Returns the kernel with the normalization flag replaced.
    pub fn normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Computes the raw (unnormalized) kernel value between `s1` and `s2`.
    ///
    /// Returns `0.0` when either string is empty. Otherwise a table indexed
    /// by (length, prefix of `s1`, prefix of `s2`) is filled layer by layer
    /// and the final cells of layers `1..=max_length` are summed.
    fn subsequence_kernel_value(&self, s1: &str, s2: &str) -> f64 {
        let left: Vec<char> = s1.chars().collect();
        let right: Vec<char> = s2.chars().collect();
        if left.is_empty() || right.is_empty() {
            return 0.0;
        }

        let (rows, cols) = (left.len(), right.len());
        let decay = self.gap_penalty;

        // Layer 0 is all ones; every higher layer starts at zero, including
        // its first row and column, which are never written.
        let mut table = vec![vec![vec![0.0_f64; cols + 1]; rows + 1]; self.max_length + 1];
        for row in table[0].iter_mut() {
            row.fill(1.0);
        }

        for k in 1..=self.max_length {
            // Split so layer k-1 can be read while layer k is written.
            let (lower, upper) = table.split_at_mut(k);
            let previous = &lower[k - 1];
            let current = &mut upper[0];

            // `i` and `j` are zero-based character positions; the cell being
            // filled is `current[i + 1][j + 1]`.
            for (i, &a) in left.iter().enumerate() {
                for (j, &b) in right.iter().enumerate() {
                    let matched = a == b;
                    let diagonal = previous[i][j];

                    // The order of these four steps is kept on purpose so
                    // floating-point rounding is unchanged.
                    let mut cell = decay * current[i][j + 1];
                    if matched {
                        cell += decay * diagonal;
                    }
                    cell += decay * current[i + 1][j];
                    if matched {
                        cell -= decay * decay * diagonal;
                    }

                    current[i + 1][j + 1] = cell;
                }
            }
        }

        (1..=self.max_length).fold(0.0, |total, k| total + table[k][rows][cols])
    }
}
400
/// A [`SubsequenceKernel`] after fitting: a copy of the training sequences
/// plus the kernel settings.
#[derive(Debug, Clone)]
pub struct FittedSubsequenceKernel {
    /// Sequences seen during fitting; each one becomes a kernel-matrix column.
    training_sequences: Vec<String>,
    /// Largest subsequence length that contributes to the kernel value.
    max_length: usize,
    /// Multiplicative decay factor applied by the recurrence.
    gap_penalty: f64,
    /// Whether `transform` rescales each kernel row to unit L2 norm.
    normalize: bool,
}
414
415impl Fit<Vec<String>, ()> for SubsequenceKernel {
416 type Fitted = FittedSubsequenceKernel;
417
418 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
419 Ok(FittedSubsequenceKernel {
420 training_sequences: sequences.clone(),
421 max_length: self.max_length,
422 gap_penalty: self.gap_penalty,
423 normalize: self.normalize,
424 })
425 }
426}
427
428impl Transform<Vec<String>, Array2<f64>> for FittedSubsequenceKernel {
429 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
430 let n_test = sequences.len();
431 let n_train = self.training_sequences.len();
432 let mut kernel_matrix = Array2::zeros((n_test, n_train));
433
434 let kernel = SubsequenceKernel {
436 max_length: self.max_length,
437 gap_penalty: self.gap_penalty,
438 normalize: false, };
440
441 for i in 0..n_test {
442 for j in 0..n_train {
443 kernel_matrix[(i, j)] =
444 kernel.subsequence_kernel_value(&sequences[i], &self.training_sequences[j]);
445 }
446
447 if self.normalize {
449 let norm = kernel_matrix.row(i).mapv(|x| x * x).sum().sqrt();
450 if norm > 0.0 {
451 for j in 0..n_train {
452 kernel_matrix[(i, j)] /= norm;
453 }
454 }
455 }
456 }
457
458 Ok(kernel_matrix)
459 }
460}
461
/// Configuration for a kernel derived from the Levenshtein edit distance.
#[derive(Debug, Clone)]
pub struct EditDistanceKernel {
    /// Distances above this threshold yield a kernel value of zero.
    max_distance: usize,
    /// Scale of the exponential decay `exp(-distance / sigma)`.
    sigma: f64,
}

impl EditDistanceKernel {
    /// Creates a kernel with the given distance cut-off and decay scale.
    ///
    /// `sigma` is used as a divisor and is not validated here.
    pub fn new(max_distance: usize, sigma: f64) -> Self {
        Self {
            max_distance,
            sigma,
        }
    }

    /// Returns the Levenshtein distance between `s1` and `s2`, counted in
    /// `char`s, with unit cost for insertion, deletion and substitution.
    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
        let chars1: Vec<char> = s1.chars().collect();
        let chars2: Vec<char> = s2.chars().collect();

        // Only the previous row of the DP table is ever read, so two rows
        // replace the full (n1 + 1) x (n2 + 1) matrix.
        // `prev[j]` starts as the cost of building the first `j` chars of
        // `s2` from the empty string.
        let mut prev: Vec<usize> = (0..=chars2.len()).collect();
        let mut curr: Vec<usize> = vec![0; chars2.len() + 1];

        for (i, &c1) in chars1.iter().enumerate() {
            // Cost of deleting the first `i + 1` chars of `s1`.
            curr[0] = i + 1;

            for (j, &c2) in chars2.iter().enumerate() {
                let deletion = prev[j + 1] + 1;
                let insertion = curr[j] + 1;
                let substitution = prev[j] + usize::from(c1 != c2);
                curr[j + 1] = deletion.min(insertion).min(substitution);
            }

            std::mem::swap(&mut prev, &mut curr);
        }

        // After the final swap `prev` holds the last completed row.
        prev[chars2.len()]
    }

    /// Returns `exp(-distance / sigma)`, or `0.0` when the edit distance
    /// exceeds `max_distance`.
    fn kernel_value(&self, s1: &str, s2: &str) -> f64 {
        let distance = self.edit_distance(s1, s2);
        if distance > self.max_distance {
            0.0
        } else {
            (-(distance as f64) / self.sigma).exp()
        }
    }
}
521
/// An [`EditDistanceKernel`] after fitting: a copy of the training sequences
/// plus the kernel settings.
#[derive(Debug, Clone)]
pub struct FittedEditDistanceKernel {
    /// Sequences seen during fitting; each one becomes a kernel-matrix column.
    training_sequences: Vec<String>,
    /// Distances above this threshold yield a kernel value of zero.
    max_distance: usize,
    /// Scale of the exponential decay `exp(-distance / sigma)`.
    sigma: f64,
}
533
534impl Fit<Vec<String>, ()> for EditDistanceKernel {
535 type Fitted = FittedEditDistanceKernel;
536
537 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
538 Ok(FittedEditDistanceKernel {
539 training_sequences: sequences.clone(),
540 max_distance: self.max_distance,
541 sigma: self.sigma,
542 })
543 }
544}
545
546impl Transform<Vec<String>, Array2<f64>> for FittedEditDistanceKernel {
547 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
548 let n_test = sequences.len();
549 let n_train = self.training_sequences.len();
550 let mut kernel_matrix = Array2::zeros((n_test, n_train));
551
552 let kernel = EditDistanceKernel {
553 max_distance: self.max_distance,
554 sigma: self.sigma,
555 };
556
557 for i in 0..n_test {
558 for j in 0..n_train {
559 kernel_matrix[(i, j)] =
560 kernel.kernel_value(&sequences[i], &self.training_sequences[j]);
561 }
562 }
563
564 Ok(kernel_matrix)
565 }
566}
567
/// Configuration for a mismatch kernel: k-mer features that also credit
/// k-mers differing in a bounded number of positions.
#[derive(Debug, Clone)]
pub struct MismatchKernel {
    /// Length, in `char`s, of every k-mer.
    k: usize,
    /// Maximum number of mismatching positions considered.
    m: usize,
    /// Characters that may be substituted at a mismatching position.
    alphabet: Vec<char>,
}

impl MismatchKernel {
    /// Creates a kernel over `k`-mers with up to `m` mismatches, using the
    /// alphabet `A`, `C`, `G`, `T`.
    pub fn new(k: usize, m: usize) -> Self {
        let alphabet = vec!['A', 'C', 'G', 'T'];
        Self { k, m, alphabet }
    }

    /// Replaces the substitution alphabet.
    pub fn alphabet(mut self, alphabet: Vec<char>) -> Self {
        self.alphabet = alphabet;
        self
    }

    /// Returns every string that differs from `kmer` in exactly `mismatches`
    /// positions, with substitutions drawn from the alphabet.
    ///
    /// With `mismatches == 0` the result is just `kmer` itself. The result is
    /// empty when `mismatches` exceeds the number of `char`s in `kmer`.
    fn generate_neighborhood(&self, kmer: &str, mismatches: usize) -> Vec<String> {
        if mismatches == 0 {
            return vec![kmer.to_string()];
        }

        let chars: Vec<char> = kmer.chars().collect();
        let mut current = Vec::with_capacity(chars.len());
        let mut neighborhood = Vec::new();

        self.generate_mismatches(&chars, 0, mismatches, &mut current, &mut neighborhood);

        neighborhood
    }

    /// Recursive worker for [`Self::generate_neighborhood`].
    ///
    /// `current` holds the characters chosen for positions `0..pos`.
    /// Completed candidates are appended to `result` once all positions are
    /// filled and exactly the requested number of mismatches has been spent.
    fn generate_mismatches(
        &self,
        original: &[char],
        pos: usize,
        mismatches_left: usize,
        current: &mut Vec<char>,
        result: &mut Vec<String>,
    ) {
        let remaining = original.len() - pos;

        // Each remaining position can absorb at most one mismatch, so a branch
        // that still owes more mismatches than positions can never complete.
        // Pruned branches produce nothing, so output order is unaffected.
        if mismatches_left > remaining {
            return;
        }

        if remaining == 0 {
            // The prune above guarantees `mismatches_left == 0` here.
            result.push(current.iter().collect());
            return;
        }

        // Keep the original character at this position.
        current.push(original[pos]);
        self.generate_mismatches(original, pos + 1, mismatches_left, current, result);
        current.pop();

        // Spend one mismatch on this position.
        if mismatches_left > 0 {
            for &c in &self.alphabet {
                if c != original[pos] {
                    current.push(c);
                    self.generate_mismatches(
                        original,
                        pos + 1,
                        mismatches_left - 1,
                        current,
                        result,
                    );
                    current.pop();
                }
            }
        }
    }
}
648
/// A [`MismatchKernel`] after fitting: the learned vocabulary of k-mers and
/// their mismatch neighbors, plus the settings it was fitted with.
#[derive(Debug, Clone)]
pub struct FittedMismatchKernel {
    /// Maps each k-mer or neighbor generated during fitting to its feature
    /// column index.
    vocabulary: HashMap<String, usize>,
    /// Length, in `char`s, of every k-mer.
    k: usize,
    /// Maximum number of mismatching positions considered.
    m: usize,
    /// Characters that may be substituted at a mismatching position.
    alphabet: Vec<char>,
}
662
663impl Fit<Vec<String>, ()> for MismatchKernel {
664 type Fitted = FittedMismatchKernel;
665
666 fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
667 let mut vocabulary = HashMap::new();
668 let mut vocab_index = 0;
669
670 for sequence in sequences {
672 let chars: Vec<char> = sequence.chars().collect();
673 for window in chars.windows(self.k) {
674 let kmer: String = window.iter().collect();
675
676 for mismatch_count in 0..=self.m {
678 let neighborhood = self.generate_neighborhood(&kmer, mismatch_count);
679 for neighbor in neighborhood {
680 if let std::collections::hash_map::Entry::Vacant(e) =
681 vocabulary.entry(neighbor)
682 {
683 e.insert(vocab_index);
684 vocab_index += 1;
685 }
686 }
687 }
688 }
689 }
690
691 Ok(FittedMismatchKernel {
692 vocabulary,
693 k: self.k,
694 m: self.m,
695 alphabet: self.alphabet.clone(),
696 })
697 }
698}
699
700impl Transform<Vec<String>, Array2<f64>> for FittedMismatchKernel {
701 fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
702 let n_sequences = sequences.len();
703 let vocab_size = self.vocabulary.len();
704 let mut features = Array2::zeros((n_sequences, vocab_size));
705
706 let kernel = MismatchKernel {
707 k: self.k,
708 m: self.m,
709 alphabet: self.alphabet.clone(),
710 };
711
712 for (i, sequence) in sequences.iter().enumerate() {
713 let chars: Vec<char> = sequence.chars().collect();
714 let mut feature_counts = HashMap::new();
715
716 for window in chars.windows(self.k) {
718 let kmer: String = window.iter().collect();
719
720 for mismatch_count in 0..=self.m {
721 let neighborhood = kernel.generate_neighborhood(&kmer, mismatch_count);
722 for neighbor in neighborhood {
723 if let Some(&vocab_idx) = self.vocabulary.get(&neighbor) {
724 *feature_counts.entry(vocab_idx).or_insert(0) += 1;
725 }
726 }
727 }
728 }
729
730 for (vocab_idx, count) in feature_counts {
732 features[(i, vocab_idx)] = count as f64;
733 }
734 }
735
736 Ok(features)
737 }
738}
739
// Unit tests for the string kernels: each test fits a kernel on a small
// corpus, transforms the same corpus and checks shape and basic invariants.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Character bigrams: one row per sequence, finite non-negative features.
    #[test]
    fn test_ngram_kernel_character() {
        let kernel = NGramKernel::new(2).mode(NGramMode::Character);
        let sequences = vec!["hello".to_string(), "world".to_string(), "help".to_string()];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Word bigrams: one row per sequence, finite non-negative features.
    #[test]
    fn test_ngram_kernel_word() {
        let kernel = NGramKernel::new(2).mode(NGramMode::Word);
        let sequences = vec![
            "hello world".to_string(),
            "world peace".to_string(),
            "hello there".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Identical input strings must produce identical spectrum feature rows.
    #[test]
    fn test_spectrum_kernel() {
        let kernel = SpectrumKernel::new(3);
        let sequences = vec![
            "ATCGATCG".to_string(),
            "GCTAGCTA".to_string(),
            "ATCGATCG".to_string(), // same string as the first entry
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);

        // Rows 0 and 2 come from the same string, so they must agree.
        for j in 0..features.ncols() {
            assert_abs_diff_eq!(features[(0, j)], features[(2, j)], epsilon = 1e-10);
        }
    }

    /// The subsequence kernel matrix is test-by-train and self-similarity is
    /// strictly positive.
    #[test]
    fn test_subsequence_kernel() {
        let kernel = SubsequenceKernel::new(3, 0.5);
        let sequences = vec!["ABC".to_string(), "ACB".to_string(), "ABC".to_string()];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert_eq!(features.ncols(), 3);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));

        assert!(features[(0, 0)] > 0.0);
        // Sequences 0 and 2 are the same string, so their rows must agree.
        assert_abs_diff_eq!(features[(0, 0)], features[(2, 0)], epsilon = 1e-10);
    }

    /// Edit-distance kernel values lie in [0, 1] and the diagonal is 1.
    #[test]
    fn test_edit_distance_kernel() {
        let kernel = EditDistanceKernel::new(5, 1.0);
        let sequences = vec![
            "cat".to_string(),
            "bat".to_string(),
            "rat".to_string(),
            "dog".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 4);
        assert!(features
            .iter()
            .all(|&x| x >= 0.0 && x <= 1.0 && x.is_finite()));

        // A string is at distance 0 from itself, and exp(0) == 1.
        for i in 0..4 {
            assert_abs_diff_eq!(features[(i, i)], 1.0, epsilon = 1e-10);
        }
    }

    /// Mismatch kernel: one row per sequence, finite non-negative features.
    #[test]
    fn test_mismatch_kernel() {
        let kernel = MismatchKernel::new(3, 1).alphabet(vec!['A', 'C', 'G', 'T']);
        let sequences = vec![
            "ATCG".to_string(),
            "ATCC".to_string(), // differs from the first entry in the last character
            "GCTA".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    /// Spot checks of the edit-distance computation on known pairs.
    #[test]
    fn test_edit_distance_computation() {
        let kernel = EditDistanceKernel::new(10, 1.0);

        assert_eq!(kernel.edit_distance("", ""), 0);
        assert_eq!(kernel.edit_distance("cat", "cat"), 0);
        assert_eq!(kernel.edit_distance("cat", "bat"), 1);
        assert_eq!(kernel.edit_distance("cat", "dog"), 3);
        assert_eq!(kernel.edit_distance("kitten", "sitting"), 3);
    }

    /// In binary mode without normalization every feature is exactly 0 or 1.
    #[test]
    fn test_ngram_binary_mode() {
        let kernel = NGramKernel::new(2).binary(true).normalize(false);

        let sequences = vec![
            "aaa".to_string(), // contains the bigram "aa" twice
            "aba".to_string(), // contains "ab" and "ba" once each
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        // A repeated bigram must still be reported as 1, not as its count.
        assert!(features.iter().all(|&x| x == 0.0 || x == 1.0));
    }
}