sklears_kernel_approximation/
string_kernels.rs

1//! String Kernel Approximations
2//!
3//! This module implements various string kernel approximation methods for
4//! sequence and text analysis. String kernels measure similarity between
5//! sequences of symbols (characters, words, etc.) by counting shared
6//! subsequences or n-grams.
7//!
8//! # Key Features
9//!
10//! - **N-gram Kernels**: Count shared n-grams between sequences
11//! - **Spectrum Kernels**: Fixed-length contiguous substring kernels
12//! - **Subsequence Kernels**: Count all shared subsequences with gaps
13//! - **Edit Distance Approximations**: Approximate edit distance kernels
14//! - **Mismatch Kernels**: Allow for mismatches in n-gram comparisons
15//! - **Weighted Subsequence Kernels**: Weight subsequences by length and gaps
16//!
17//! # Mathematical Background
18//!
19//! String kernel between sequences s and t:
20//! K(s, t) = Σ φ(s)[u] * φ(t)[u]
21//!
22//! Where φ(s)[u] is the feature map that counts occurrences of substring u.
23//!
24//! # References
25//!
26//! - Shawe-Taylor, J., & Cristianini, N. (2004). Kernel methods for pattern analysis
27//! - Lodhi, H., et al. (2002). Text classification using string kernels
28
29use scirs2_core::ndarray::Array2;
30use sklears_core::{
31    error::Result,
32    prelude::{Fit, Transform},
33};
34use std::collections::HashMap;
35
/// N-gram kernel for sequences.
///
/// Extracts overlapping n-grams (character-, word-, or custom-delimited)
/// from each input string; `fit` builds a vocabulary over the training
/// corpus and `transform` maps sequences to (optionally L2-normalized,
/// optionally binary) n-gram count vectors.
#[derive(Debug, Clone)]
pub struct NGramKernel {
    /// N-gram size (window length). NOTE(review): extraction uses
    /// `windows(self.n)`, which panics for n == 0 — callers are assumed
    /// to pass n >= 1; confirm upstream validation.
    n: usize,
    /// Whether to L2-normalize each feature row after counting
    normalize: bool,
    /// Whether to use binary features (presence/absence vs counts)
    binary: bool,
    /// Character-level vs word-level n-grams
    mode: NGramMode,
}
49
/// N-gram extraction mode: how a sequence is tokenized before windowing.
#[derive(Debug, Clone)]
pub enum NGramMode {
    /// N-grams over individual characters (Unicode scalar values)
    Character,
    /// N-grams over whitespace-separated words, re-joined with a single space
    Word,
    /// N-grams over tokens split on `delimiter`, re-joined with that delimiter
    Custom { delimiter: String },
}
58
/// Fitted n-gram kernel: the frozen vocabulary plus the extraction
/// settings captured at fit time, used by `transform` to build features.
#[derive(Debug, Clone)]
pub struct FittedNGramKernel {
    /// Vocabulary mapping from n-gram to feature column index
    vocabulary: HashMap<String, usize>,
    /// N-gram size
    n: usize,
    /// Whether feature rows are L2-normalized
    normalize: bool,
    /// Whether features are binary (presence/absence) instead of counts
    binary: bool,
    /// Tokenization mode used for extraction
    mode: NGramMode,
}
74
75impl NGramKernel {
76    /// Create new n-gram kernel
77    pub fn new(n: usize) -> Self {
78        Self {
79            n,
80            normalize: true,
81            binary: false,
82            mode: NGramMode::Character,
83        }
84    }
85
86    /// Set normalization
87    pub fn normalize(mut self, normalize: bool) -> Self {
88        self.normalize = normalize;
89        self
90    }
91
92    /// Set binary mode
93    pub fn binary(mut self, binary: bool) -> Self {
94        self.binary = binary;
95        self
96    }
97
98    /// Set n-gram mode
99    pub fn mode(mut self, mode: NGramMode) -> Self {
100        self.mode = mode;
101        self
102    }
103
104    /// Extract n-grams from a sequence
105    fn extract_ngrams(&self, sequence: &str) -> Vec<String> {
106        match &self.mode {
107            NGramMode::Character => {
108                let chars: Vec<char> = sequence.chars().collect();
109                chars
110                    .windows(self.n)
111                    .map(|window| window.iter().collect())
112                    .collect()
113            }
114            NGramMode::Word => {
115                let words: Vec<&str> = sequence.split_whitespace().collect();
116                words
117                    .windows(self.n)
118                    .map(|window| window.join(" "))
119                    .collect()
120            }
121            NGramMode::Custom { delimiter } => {
122                let tokens: Vec<&str> = sequence.split(delimiter).collect();
123                tokens
124                    .windows(self.n)
125                    .map(|window| window.join(delimiter))
126                    .collect()
127            }
128        }
129    }
130}
131
132impl Fit<Vec<String>, ()> for NGramKernel {
133    type Fitted = FittedNGramKernel;
134
135    fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
136        let mut vocabulary = HashMap::new();
137        let mut vocab_index = 0;
138
139        // Build vocabulary from all sequences
140        for sequence in sequences {
141            let ngrams = self.extract_ngrams(sequence);
142            for ngram in ngrams {
143                if let std::collections::hash_map::Entry::Vacant(e) = vocabulary.entry(ngram) {
144                    e.insert(vocab_index);
145                    vocab_index += 1;
146                }
147            }
148        }
149
150        Ok(FittedNGramKernel {
151            vocabulary,
152            n: self.n,
153            normalize: self.normalize,
154            binary: self.binary,
155            mode: self.mode.clone(),
156        })
157    }
158}
159
160impl Transform<Vec<String>, Array2<f64>> for FittedNGramKernel {
161    fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
162        let n_sequences = sequences.len();
163        let vocab_size = self.vocabulary.len();
164        let mut features = Array2::zeros((n_sequences, vocab_size));
165
166        for (i, sequence) in sequences.iter().enumerate() {
167            let ngrams = match &self.mode {
168                NGramMode::Character => {
169                    let chars: Vec<char> = sequence.chars().collect();
170                    chars
171                        .windows(self.n)
172                        .map(|window| window.iter().collect::<String>())
173                        .collect::<Vec<String>>()
174                }
175                NGramMode::Word => {
176                    let words: Vec<&str> = sequence.split_whitespace().collect();
177                    words
178                        .windows(self.n)
179                        .map(|window| window.join(" "))
180                        .collect::<Vec<String>>()
181                }
182                NGramMode::Custom { delimiter } => {
183                    let tokens: Vec<&str> = sequence.split(delimiter).collect();
184                    tokens
185                        .windows(self.n)
186                        .map(|window| window.join(delimiter))
187                        .collect::<Vec<String>>()
188                }
189            };
190
191            // Count n-grams
192            let mut ngram_counts = HashMap::new();
193            for ngram in ngrams {
194                if let Some(&vocab_idx) = self.vocabulary.get(&ngram) {
195                    *ngram_counts.entry(vocab_idx).or_insert(0) += 1;
196                }
197            }
198
199            // Fill feature vector
200            for (vocab_idx, count) in ngram_counts {
201                features[(i, vocab_idx)] = if self.binary { 1.0 } else { count as f64 };
202            }
203
204            // Normalize if requested
205            if self.normalize {
206                let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
207                if norm > 0.0 {
208                    for j in 0..vocab_size {
209                        features[(i, j)] /= norm;
210                    }
211                }
212            }
213        }
214
215        Ok(features)
216    }
217}
218
/// Spectrum kernel: features are counts of fixed-length contiguous
/// character substrings (k-mers) shared between sequences.
#[derive(Debug, Clone)]
pub struct SpectrumKernel {
    /// Substring length (k-mer size)
    k: usize,
    /// Whether to L2-normalize each feature row
    normalize: bool,
}
228
229impl SpectrumKernel {
230    /// Create new spectrum kernel
231    pub fn new(k: usize) -> Self {
232        Self { k, normalize: true }
233    }
234
235    /// Set normalization
236    pub fn normalize(mut self, normalize: bool) -> Self {
237        self.normalize = normalize;
238        self
239    }
240}
241
/// Fitted spectrum kernel: the k-mer vocabulary discovered at fit time
/// plus the configuration needed by `transform`.
#[derive(Debug, Clone)]
pub struct FittedSpectrumKernel {
    /// Vocabulary mapping each k-mer to a feature column index
    vocabulary: HashMap<String, usize>,
    /// K-mer length
    k: usize,
    /// Whether feature rows are L2-normalized
    normalize: bool,
}
253
254impl Fit<Vec<String>, ()> for SpectrumKernel {
255    type Fitted = FittedSpectrumKernel;
256
257    fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
258        let mut vocabulary = HashMap::new();
259        let mut vocab_index = 0;
260
261        // Build vocabulary of all k-mers
262        for sequence in sequences {
263            let chars: Vec<char> = sequence.chars().collect();
264            for window in chars.windows(self.k) {
265                let kmer: String = window.iter().collect();
266                if let std::collections::hash_map::Entry::Vacant(e) = vocabulary.entry(kmer) {
267                    e.insert(vocab_index);
268                    vocab_index += 1;
269                }
270            }
271        }
272
273        Ok(FittedSpectrumKernel {
274            vocabulary,
275            k: self.k,
276            normalize: self.normalize,
277        })
278    }
279}
280
281impl Transform<Vec<String>, Array2<f64>> for FittedSpectrumKernel {
282    fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
283        let n_sequences = sequences.len();
284        let vocab_size = self.vocabulary.len();
285        let mut features = Array2::zeros((n_sequences, vocab_size));
286
287        for (i, sequence) in sequences.iter().enumerate() {
288            let chars: Vec<char> = sequence.chars().collect();
289            let mut kmer_counts = HashMap::new();
290
291            // Count k-mers
292            for window in chars.windows(self.k) {
293                let kmer: String = window.iter().collect();
294                if let Some(&vocab_idx) = self.vocabulary.get(&kmer) {
295                    *kmer_counts.entry(vocab_idx).or_insert(0) += 1;
296                }
297            }
298
299            // Fill feature vector
300            for (vocab_idx, count) in kmer_counts {
301                features[(i, vocab_idx)] = count as f64;
302            }
303
304            // Normalize if requested
305            if self.normalize {
306                let norm = features.row(i).mapv(|x| x * x).sum().sqrt();
307                if norm > 0.0 {
308                    for j in 0..vocab_size {
309                        features[(i, j)] /= norm;
310                    }
311                }
312            }
313        }
314
315        Ok(features)
316    }
317}
318
/// Subsequence kernel that counts all shared subsequences (with gaps),
/// each weighted by a per-gap decay factor.
#[derive(Debug, Clone)]
pub struct SubsequenceKernel {
    /// Maximum subsequence length counted
    max_length: usize,
    /// Gap penalty (lambda parameter, 0 < lambda <= 1); smaller values
    /// penalize gapped matches more heavily
    gap_penalty: f64,
    /// Whether to normalize each row of the output kernel matrix
    normalize: bool,
}
330
331impl SubsequenceKernel {
332    /// Create new subsequence kernel
333    pub fn new(max_length: usize, gap_penalty: f64) -> Self {
334        Self {
335            max_length,
336            gap_penalty,
337            normalize: true,
338        }
339    }
340
341    /// Set normalization
342    pub fn normalize(mut self, normalize: bool) -> Self {
343        self.normalize = normalize;
344        self
345    }
346
347    /// Compute subsequence kernel between two sequences using dynamic programming
348    fn subsequence_kernel_value(&self, s1: &str, s2: &str) -> f64 {
349        let chars1: Vec<char> = s1.chars().collect();
350        let chars2: Vec<char> = s2.chars().collect();
351        let n1 = chars1.len();
352        let n2 = chars2.len();
353
354        if n1 == 0 || n2 == 0 {
355            return 0.0;
356        }
357
358        let mut dp = vec![vec![vec![0.0; n2 + 1]; n1 + 1]; self.max_length + 1];
359
360        // Initialize base cases
361        for i in 0..=n1 {
362            for j in 0..=n2 {
363                dp[0][i][j] = 1.0;
364            }
365        }
366
367        // Fill DP table
368        for k in 1..=self.max_length {
369            for i in 1..=n1 {
370                for j in 1..=n2 {
371                    // Case 1: don't include chars1[i-1]
372                    dp[k][i][j] = self.gap_penalty * dp[k][i - 1][j];
373
374                    // Case 2: include chars1[i-1] if it matches chars2[j-1]
375                    if chars1[i - 1] == chars2[j - 1] {
376                        dp[k][i][j] += self.gap_penalty * dp[k - 1][i - 1][j - 1];
377                    }
378
379                    // Case 3: don't include chars2[j-1]
380                    dp[k][i][j] += self.gap_penalty * dp[k][i][j - 1];
381
382                    // Case 4: don't include both
383                    if chars1[i - 1] == chars2[j - 1] {
384                        dp[k][i][j] -=
385                            self.gap_penalty * self.gap_penalty * dp[k - 1][i - 1][j - 1];
386                    }
387                }
388            }
389        }
390
391        // Sum over all subsequence lengths
392        let mut total = 0.0;
393        for k in 1..=self.max_length {
394            total += dp[k][n1][n2];
395        }
396
397        total
398    }
399}
400
/// Fitted subsequence kernel. Fitting memorizes the training sequences;
/// `transform` computes the full test-vs-train kernel matrix lazily.
#[derive(Debug, Clone)]
pub struct FittedSubsequenceKernel {
    /// Training sequences stored at fit time
    training_sequences: Vec<String>,
    /// Maximum subsequence length
    max_length: usize,
    /// Gap penalty (lambda)
    gap_penalty: f64,
    /// Whether to normalize each kernel-matrix row
    normalize: bool,
}
414
415impl Fit<Vec<String>, ()> for SubsequenceKernel {
416    type Fitted = FittedSubsequenceKernel;
417
418    fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
419        Ok(FittedSubsequenceKernel {
420            training_sequences: sequences.clone(),
421            max_length: self.max_length,
422            gap_penalty: self.gap_penalty,
423            normalize: self.normalize,
424        })
425    }
426}
427
428impl Transform<Vec<String>, Array2<f64>> for FittedSubsequenceKernel {
429    fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
430        let n_test = sequences.len();
431        let n_train = self.training_sequences.len();
432        let mut kernel_matrix = Array2::zeros((n_test, n_train));
433
434        // Temporary kernel instance for computation
435        let kernel = SubsequenceKernel {
436            max_length: self.max_length,
437            gap_penalty: self.gap_penalty,
438            normalize: false, // Handle normalization separately
439        };
440
441        for i in 0..n_test {
442            for j in 0..n_train {
443                kernel_matrix[(i, j)] =
444                    kernel.subsequence_kernel_value(&sequences[i], &self.training_sequences[j]);
445            }
446
447            // Normalize row if requested
448            if self.normalize {
449                let norm = kernel_matrix.row(i).mapv(|x| x * x).sum().sqrt();
450                if norm > 0.0 {
451                    for j in 0..n_train {
452                        kernel_matrix[(i, j)] /= norm;
453                    }
454                }
455            }
456        }
457
458        Ok(kernel_matrix)
459    }
460}
461
/// Edit distance approximation kernel: similarity decays exponentially
/// with Levenshtein distance and is truncated to zero beyond a cutoff.
#[derive(Debug, Clone)]
pub struct EditDistanceKernel {
    /// Maximum edit distance to consider (larger distances map to 0.0)
    max_distance: usize,
    /// Kernel bandwidth parameter in exp(-distance / sigma)
    sigma: f64,
}
471
472impl EditDistanceKernel {
473    /// Create new edit distance kernel
474    pub fn new(max_distance: usize, sigma: f64) -> Self {
475        Self {
476            max_distance,
477            sigma,
478        }
479    }
480
481    /// Compute edit distance between two strings
482    fn edit_distance(&self, s1: &str, s2: &str) -> usize {
483        let chars1: Vec<char> = s1.chars().collect();
484        let chars2: Vec<char> = s2.chars().collect();
485        let n1 = chars1.len();
486        let n2 = chars2.len();
487
488        let mut dp = vec![vec![0; n2 + 1]; n1 + 1];
489
490        // Initialize first row and column
491        for i in 0..=n1 {
492            dp[i][0] = i;
493        }
494        for j in 0..=n2 {
495            dp[0][j] = j;
496        }
497
498        // Fill DP table
499        for i in 1..=n1 {
500            for j in 1..=n2 {
501                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
502                dp[i][j] = (dp[i - 1][j] + 1)
503                    .min(dp[i][j - 1] + 1)
504                    .min(dp[i - 1][j - 1] + cost);
505            }
506        }
507
508        dp[n1][n2]
509    }
510
511    /// Compute kernel value from edit distance
512    fn kernel_value(&self, s1: &str, s2: &str) -> f64 {
513        let distance = self.edit_distance(s1, s2);
514        if distance > self.max_distance {
515            0.0
516        } else {
517            (-(distance as f64) / self.sigma).exp()
518        }
519    }
520}
521
/// Fitted edit distance kernel: stored training sequences plus the kernel
/// parameters; distances are computed on demand in `transform`.
#[derive(Debug, Clone)]
pub struct FittedEditDistanceKernel {
    /// Training sequences stored at fit time
    training_sequences: Vec<String>,
    /// Maximum distance cutoff
    max_distance: usize,
    /// Bandwidth (sigma) parameter
    sigma: f64,
}
533
534impl Fit<Vec<String>, ()> for EditDistanceKernel {
535    type Fitted = FittedEditDistanceKernel;
536
537    fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
538        Ok(FittedEditDistanceKernel {
539            training_sequences: sequences.clone(),
540            max_distance: self.max_distance,
541            sigma: self.sigma,
542        })
543    }
544}
545
546impl Transform<Vec<String>, Array2<f64>> for FittedEditDistanceKernel {
547    fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
548        let n_test = sequences.len();
549        let n_train = self.training_sequences.len();
550        let mut kernel_matrix = Array2::zeros((n_test, n_train));
551
552        let kernel = EditDistanceKernel {
553            max_distance: self.max_distance,
554            sigma: self.sigma,
555        };
556
557        for i in 0..n_test {
558            for j in 0..n_train {
559                kernel_matrix[(i, j)] =
560                    kernel.kernel_value(&sequences[i], &self.training_sequences[j]);
561            }
562        }
563
564        Ok(kernel_matrix)
565    }
566}
567
/// Mismatch kernel: compares k-mers while allowing up to m substitutions
/// drawn from a fixed alphabet.
#[derive(Debug, Clone)]
pub struct MismatchKernel {
    /// N-gram (k-mer) length
    k: usize,
    /// Number of allowed mismatches (substitutions)
    m: usize,
    /// Alphabet (set of allowed characters); defaults to DNA {A, C, G, T}
    alphabet: Vec<char>,
}
579
580impl MismatchKernel {
581    /// Create new mismatch kernel
582    pub fn new(k: usize, m: usize) -> Self {
583        // Default DNA alphabet
584        let alphabet = vec!['A', 'C', 'G', 'T'];
585        Self { k, m, alphabet }
586    }
587
588    /// Set custom alphabet
589    pub fn alphabet(mut self, alphabet: Vec<char>) -> Self {
590        self.alphabet = alphabet;
591        self
592    }
593
594    /// Generate all possible k-mers with up to m mismatches from a given k-mer
595    fn generate_neighborhood(&self, kmer: &str, mismatches: usize) -> Vec<String> {
596        if mismatches == 0 {
597            return vec![kmer.to_string()];
598        }
599
600        let chars: Vec<char> = kmer.chars().collect();
601        let mut neighborhood = Vec::new();
602
603        // Generate all combinations with exactly 'mismatches' positions changed
604        self.generate_mismatches(&chars, 0, mismatches, &mut vec![], &mut neighborhood);
605
606        neighborhood
607    }
608
609    /// Recursive helper for generating mismatches
610    fn generate_mismatches(
611        &self,
612        original: &[char],
613        pos: usize,
614        mismatches_left: usize,
615        current: &mut Vec<char>,
616        result: &mut Vec<String>,
617    ) {
618        if pos == original.len() {
619            if mismatches_left == 0 {
620                result.push(current.iter().collect());
621            }
622            return;
623        }
624
625        // Option 1: Keep original character
626        current.push(original[pos]);
627        self.generate_mismatches(original, pos + 1, mismatches_left, current, result);
628        current.pop();
629
630        // Option 2: Try all alphabet characters (if we still have mismatches to use)
631        if mismatches_left > 0 {
632            for &c in &self.alphabet {
633                if c != original[pos] {
634                    current.push(c);
635                    self.generate_mismatches(
636                        original,
637                        pos + 1,
638                        mismatches_left - 1,
639                        current,
640                        result,
641                    );
642                    current.pop();
643                }
644            }
645        }
646    }
647}
648
/// Fitted mismatch kernel: vocabulary of all training k-mers and their
/// mismatch neighborhoods, plus the configuration used by `transform`.
#[derive(Debug, Clone)]
pub struct FittedMismatchKernel {
    /// Feature vocabulary mapping k-mer to column index
    vocabulary: HashMap<String, usize>,
    /// K-mer length
    k: usize,
    /// Number of allowed mismatches
    m: usize,
    /// Alphabet used for substitutions
    alphabet: Vec<char>,
}
662
663impl Fit<Vec<String>, ()> for MismatchKernel {
664    type Fitted = FittedMismatchKernel;
665
666    fn fit(self, sequences: &Vec<String>, _y: &()) -> Result<Self::Fitted> {
667        let mut vocabulary = HashMap::new();
668        let mut vocab_index = 0;
669
670        // Build vocabulary of all possible k-mers with mismatches
671        for sequence in sequences {
672            let chars: Vec<char> = sequence.chars().collect();
673            for window in chars.windows(self.k) {
674                let kmer: String = window.iter().collect();
675
676                // Generate neighborhood with up to m mismatches
677                for mismatch_count in 0..=self.m {
678                    let neighborhood = self.generate_neighborhood(&kmer, mismatch_count);
679                    for neighbor in neighborhood {
680                        if let std::collections::hash_map::Entry::Vacant(e) =
681                            vocabulary.entry(neighbor)
682                        {
683                            e.insert(vocab_index);
684                            vocab_index += 1;
685                        }
686                    }
687                }
688            }
689        }
690
691        Ok(FittedMismatchKernel {
692            vocabulary,
693            k: self.k,
694            m: self.m,
695            alphabet: self.alphabet.clone(),
696        })
697    }
698}
699
700impl Transform<Vec<String>, Array2<f64>> for FittedMismatchKernel {
701    fn transform(&self, sequences: &Vec<String>) -> Result<Array2<f64>> {
702        let n_sequences = sequences.len();
703        let vocab_size = self.vocabulary.len();
704        let mut features = Array2::zeros((n_sequences, vocab_size));
705
706        let kernel = MismatchKernel {
707            k: self.k,
708            m: self.m,
709            alphabet: self.alphabet.clone(),
710        };
711
712        for (i, sequence) in sequences.iter().enumerate() {
713            let chars: Vec<char> = sequence.chars().collect();
714            let mut feature_counts = HashMap::new();
715
716            // Extract k-mers and generate their neighborhoods
717            for window in chars.windows(self.k) {
718                let kmer: String = window.iter().collect();
719
720                for mismatch_count in 0..=self.m {
721                    let neighborhood = kernel.generate_neighborhood(&kmer, mismatch_count);
722                    for neighbor in neighborhood {
723                        if let Some(&vocab_idx) = self.vocabulary.get(&neighbor) {
724                            *feature_counts.entry(vocab_idx).or_insert(0) += 1;
725                        }
726                    }
727                }
728            }
729
730            // Fill feature vector
731            for (vocab_idx, count) in feature_counts {
732                features[(i, vocab_idx)] = count as f64;
733            }
734        }
735
736        Ok(features)
737    }
738}
739
// Note: a file-wide `#[allow(non_snake_case)]` previously sat here; every
// test function below is snake_case, so the allow suppressed nothing and
// has been removed.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_ngram_kernel_character() {
        let kernel = NGramKernel::new(2).mode(NGramMode::Character);
        let sequences = vec!["hello".to_string(), "world".to_string(), "help".to_string()];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    #[test]
    fn test_ngram_kernel_word() {
        let kernel = NGramKernel::new(2).mode(NGramMode::Word);
        let sequences = vec![
            "hello world".to_string(),
            "world peace".to_string(),
            "hello there".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    #[test]
    fn test_spectrum_kernel() {
        let kernel = SpectrumKernel::new(3);
        let sequences = vec![
            "ATCGATCG".to_string(),
            "GCTAGCTA".to_string(),
            "ATCGATCG".to_string(), // duplicate
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);

        // First and third sequences should be identical
        for j in 0..features.ncols() {
            assert_abs_diff_eq!(features[(0, j)], features[(2, j)], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_subsequence_kernel() {
        let kernel = SubsequenceKernel::new(3, 0.5);
        let sequences = vec!["ABC".to_string(), "ACB".to_string(), "ABC".to_string()];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert_eq!(features.ncols(), 3);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));

        // Kernel should be symmetric for identical sequences
        assert!(features[(0, 0)] > 0.0);
        assert_abs_diff_eq!(features[(0, 0)], features[(2, 0)], epsilon = 1e-10);
    }

    #[test]
    fn test_edit_distance_kernel() {
        let kernel = EditDistanceKernel::new(5, 1.0);
        let sequences = vec![
            "cat".to_string(),
            "bat".to_string(),
            "rat".to_string(),
            "dog".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 4);
        assert!(features
            .iter()
            .all(|&x| x >= 0.0 && x <= 1.0 && x.is_finite()));

        // Self-similarity should be 1.0
        for i in 0..4 {
            assert_abs_diff_eq!(features[(i, i)], 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_mismatch_kernel() {
        let kernel = MismatchKernel::new(3, 1).alphabet(vec!['A', 'C', 'G', 'T']);
        let sequences = vec![
            "ATCG".to_string(),
            "ATCC".to_string(), // 1 mismatch from first
            "GCTA".to_string(),
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        assert_eq!(features.nrows(), 3);
        assert!(features.ncols() > 0);
        assert!(features.iter().all(|&x| x >= 0.0 && x.is_finite()));
    }

    #[test]
    fn test_edit_distance_computation() {
        let kernel = EditDistanceKernel::new(10, 1.0);

        assert_eq!(kernel.edit_distance("", ""), 0);
        assert_eq!(kernel.edit_distance("cat", "cat"), 0);
        assert_eq!(kernel.edit_distance("cat", "bat"), 1);
        assert_eq!(kernel.edit_distance("cat", "dog"), 3);
        assert_eq!(kernel.edit_distance("kitten", "sitting"), 3);
    }

    #[test]
    fn test_ngram_binary_mode() {
        let kernel = NGramKernel::new(2).binary(true).normalize(false);

        let sequences = vec![
            "aaa".to_string(), // "aa" appears twice
            "aba".to_string(), // "ab" and "ba" appear once each
        ];

        let fitted = kernel.fit(&sequences, &()).unwrap();
        let features = fitted.transform(&sequences).unwrap();

        // In binary mode, repeated n-grams should only count as 1
        assert!(features.iter().all(|&x| x == 0.0 || x == 1.0));
    }
}