Skip to main content

trustformers_tokenizers/
bio.rs

1//! Biological sequence tokenizer for TrustformeRS
2//!
3//! This module provides specialized tokenization for biological sequences
4//! including DNA, RNA, and protein sequences used in bioinformatics.
5
6use once_cell::sync::Lazy;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use trustformers_core::errors::Result;
11use trustformers_core::traits::{TokenizedInput, Tokenizer};
12
/// Configuration for biological sequence tokenizer
///
/// All fields are public and the struct derives `Serialize`/`Deserialize`.
/// `BioTokenizerConfig::default()` supplies the stock settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BioTokenizerConfig {
    /// Maximum sequence length
    ///
    /// NOTE(review): not read by any code in the reviewed part of this file;
    /// confirm where (or whether) the limit is enforced.
    pub max_length: Option<usize>,
    /// Whether to include special bio tokens
    ///
    /// When set, markers such as `[CLS]`, `[SEP]`, `[PAD]`, `[UNK]`, `[MASK]`
    /// and the `[START_*]`/`[END_*]` tokens are registered in the vocabulary.
    pub include_special_tokens: bool,
    /// Whether to tokenize DNA sequences
    ///
    /// Nucleotide symbols are added to the vocabulary when this or
    /// `tokenize_rna` is set.
    pub tokenize_dna: bool,
    /// Whether to tokenize RNA sequences
    pub tokenize_rna: bool,
    /// Whether to tokenize protein sequences
    ///
    /// When set, the amino-acid letters are added to the vocabulary.
    pub tokenize_proteins: bool,
    /// K-mer size for subsequence tokenization
    ///
    /// `Some(k)` splits DNA, RNA and protein sequences into windows of `k`
    /// characters; `None` produces one token per character.
    pub kmer_size: Option<usize>,
    /// Whether to use overlapping k-mers
    ///
    /// Overlapping windows advance by one character, non-overlapping windows
    /// advance by `k`.
    pub overlapping_kmers: bool,
    /// Whether to preserve case (for mixed case sequences)
    ///
    /// When `false` the input is upper-cased before its type is detected.
    pub preserve_case: bool,
    /// Whether to handle ambiguous nucleotides/amino acids
    ///
    /// NOTE(review): not read by any code in the reviewed part of this file.
    pub handle_ambiguous: bool,
    /// Whether to tokenize secondary structure annotations
    ///
    /// When set, the symbols `H`, `E`, `C` and `T` are added to the vocabulary.
    pub tokenize_structure: bool,
    /// Vocabulary size limit
    ///
    /// Checked while k-mers are generated; `None` is treated as 5000 there.
    pub vocab_size: Option<usize>,
}
39
40impl Default for BioTokenizerConfig {
41    fn default() -> Self {
42        Self {
43            max_length: Some(2048),
44            include_special_tokens: true,
45            tokenize_dna: true,
46            tokenize_rna: true,
47            tokenize_proteins: true,
48            kmer_size: Some(3), // Codons for DNA/RNA
49            overlapping_kmers: true,
50            preserve_case: false,
51            handle_ambiguous: true,
52            tokenize_structure: false,
53            vocab_size: Some(5000),
54        }
55    }
56}
57
/// Types of biological tokens
///
/// Attached to every [`BioToken`] to say what kind of symbol or window the
/// token text represents.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BioTokenType {
    /// DNA nucleotides (A, T, G, C)
    DNANucleotide,
    /// RNA nucleotides (A, U, G, C)
    RNANucleotide,
    /// Amino acids (20 standard + special)
    AminoAcid,
    /// K-mer sequences
    ///
    /// Used for fixed-width windows that are not classified as a start or
    /// stop codon.
    Kmer,
    /// Ambiguous nucleotides (N, R, Y, etc.)
    ///
    /// Assigned to symbols that are in the nucleotide table but are not one
    /// of the four bases of the sequence type being tokenized.
    AmbiguousNucleotide,
    /// Ambiguous amino acids (X, B, Z)
    ///
    /// NOTE(review): the amino-acid table in this file only defines `X`;
    /// `B` and `Z` are not present there.
    AmbiguousAminoAcid,
    /// Stop codons
    ///
    /// Three-letter windows that the genetic code maps to `*`
    /// (`TAA`, `TAG`, `TGA`).
    StopCodon,
    /// Start codons
    ///
    /// The `ATG` window.
    StartCodon,
    /// Secondary structure elements
    SecondaryStructure,
    /// Sequence modifications
    ///
    /// NOTE(review): not produced by any code in the reviewed part of this file.
    Modification,
    /// Special sequence markers
    ///
    /// NOTE(review): not produced by any code in the reviewed part of this file.
    Special,
    /// Unknown sequences
    ///
    /// Used for symbols outside the relevant table and for every token of a
    /// sequence whose type could not be detected.
    Unknown,
}
86
/// Biological token with metadata
///
/// One element of the output of `BioTokenizer::tokenize_bio`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BioToken {
    /// Token text
    ///
    /// A single character, or a `k`-character window when k-mer tokenization
    /// is active.
    pub text: String,
    /// Token type
    pub token_type: BioTokenType,
    /// Start position in original sequence
    ///
    /// Single-character tokens record the byte offset reported by
    /// `char_indices`; k-mer tokens record the character index of the window.
    /// The two agree for ASCII input.
    pub start: usize,
    /// End position in original sequence
    ///
    /// Exclusive upper bound matching `start`.
    pub end: usize,
    /// Biological metadata
    ///
    /// `None` when nothing is known about the token (for example symbols
    /// missing from the lookup tables, or fallback tokenization).
    pub metadata: Option<BioTokenMetadata>,
}
101
/// Metadata for biological tokens
///
/// Every field is optional; `Default` yields a value with all fields `None`,
/// and each tokenization path fills in only the fields it can compute.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BioTokenMetadata {
    /// Molecular weight (for amino acids)
    ///
    /// Also filled from the nucleotide table for single-nucleotide tokens.
    /// Units are presumably g/mol (TODO confirm).
    pub molecular_weight: Option<f64>,
    /// Hydrophobicity (for amino acids)
    pub hydrophobicity: Option<f64>,
    /// Charge (for amino acids)
    pub charge: Option<i8>,
    /// GC content (for nucleotide sequences)
    ///
    /// Fraction of `G`/`C` characters in the token, between 0.0 and 1.0.
    pub gc_content: Option<f64>,
    /// Melting temperature (for DNA/RNA)
    ///
    /// NOTE(review): never set by the code in the reviewed part of this file.
    pub melting_temp: Option<f64>,
    /// Codon table position
    ///
    /// NOTE(review): never set by the code in the reviewed part of this file.
    pub codon_position: Option<u8>,
    /// Reading frame
    ///
    /// For three-character k-mers this is the window position modulo 3.
    pub reading_frame: Option<u8>,
    /// Secondary structure type
    ///
    /// Human-readable label such as "Helix" or "Beta sheet".
    pub structure_type: Option<String>,
}
122
/// Biological sequence tokenizer
///
/// Holds the configuration, the token/id vocabulary in both directions and
/// private copies of the static lookup tables.
pub struct BioTokenizer {
    // Settings supplied at construction time.
    config: BioTokenizerConfig,
    // Token text -> id.
    vocab: HashMap<String, u32>,
    // Id -> token text (inverse of `vocab`).
    id_to_token: HashMap<u32, String>,
    // Id that the next newly registered token will receive.
    next_id: u32,
    // Copy of the amino-acid property table, keyed by one-letter code.
    amino_acids: HashMap<char, AminoAcidInfo>,
    // Copy of the nucleotide property table, keyed by symbol.
    nucleotides: HashMap<char, NucleotideInfo>,
    // Copy of the codon -> amino acid table (`*` marks a stop codon).
    genetic_code: HashMap<String, char>,
    // Compiled secondary-structure regexes; currently unused, hence the allow.
    #[allow(dead_code)]
    structure_patterns: Vec<Regex>,
}
135
/// Amino acid information
///
/// One row of the amino-acid property table. Fields marked
/// `#[allow(dead_code)]` are stored but not read by the tokenizer.
#[derive(Debug, Clone)]
struct AminoAcidInfo {
    // Full name, e.g. "Alanine".
    #[allow(dead_code)]
    name: String,
    // Molecular weight; presumably g/mol of the free amino acid (TODO confirm).
    molecular_weight: f64,
    // Hydropathy score; values look like the Kyte-Doolittle scale (TODO confirm).
    hydrophobicity: f64,
    // Net charge: -1, 0 or +1 in the table.
    charge: i8,
    // One-letter code; equals the key under which the row is stored.
    #[allow(dead_code)]
    single_letter: char,
    // Three-letter code, e.g. "Ala".
    #[allow(dead_code)]
    three_letter: String,
}
149
/// Nucleotide information
///
/// One row of the nucleotide property table, covering both concrete bases
/// and ambiguity codes.
#[derive(Debug, Clone)]
struct NucleotideInfo {
    // Full name, e.g. "Adenine" or "Any nucleotide".
    #[allow(dead_code)]
    name: String,
    // Symbol of the complementary base or ambiguity code.
    complement: char,
    // `true` for purines (A, G and the ambiguity code R in the table).
    #[allow(dead_code)]
    is_purine: bool,
    // Molecular weight; 0.0 for ambiguity codes. Units presumably g/mol (TODO confirm).
    molecular_weight: f64,
}
160
161// Static data for amino acids
162static AMINO_ACIDS: Lazy<HashMap<char, AminoAcidInfo>> = Lazy::new(|| {
163    let mut map = HashMap::new();
164
165    // Standard amino acids with properties
166    map.insert(
167        'A',
168        AminoAcidInfo {
169            name: "Alanine".to_string(),
170            molecular_weight: 89.1,
171            hydrophobicity: 1.8,
172            charge: 0,
173            single_letter: 'A',
174            three_letter: "Ala".to_string(),
175        },
176    );
177    map.insert(
178        'R',
179        AminoAcidInfo {
180            name: "Arginine".to_string(),
181            molecular_weight: 174.2,
182            hydrophobicity: -4.5,
183            charge: 1,
184            single_letter: 'R',
185            three_letter: "Arg".to_string(),
186        },
187    );
188    map.insert(
189        'N',
190        AminoAcidInfo {
191            name: "Asparagine".to_string(),
192            molecular_weight: 132.1,
193            hydrophobicity: -3.5,
194            charge: 0,
195            single_letter: 'N',
196            three_letter: "Asn".to_string(),
197        },
198    );
199    map.insert(
200        'D',
201        AminoAcidInfo {
202            name: "Aspartic acid".to_string(),
203            molecular_weight: 133.1,
204            hydrophobicity: -3.5,
205            charge: -1,
206            single_letter: 'D',
207            three_letter: "Asp".to_string(),
208        },
209    );
210    map.insert(
211        'C',
212        AminoAcidInfo {
213            name: "Cysteine".to_string(),
214            molecular_weight: 121.0,
215            hydrophobicity: 2.5,
216            charge: 0,
217            single_letter: 'C',
218            three_letter: "Cys".to_string(),
219        },
220    );
221    map.insert(
222        'E',
223        AminoAcidInfo {
224            name: "Glutamic acid".to_string(),
225            molecular_weight: 147.1,
226            hydrophobicity: -3.5,
227            charge: -1,
228            single_letter: 'E',
229            three_letter: "Glu".to_string(),
230        },
231    );
232    map.insert(
233        'Q',
234        AminoAcidInfo {
235            name: "Glutamine".to_string(),
236            molecular_weight: 146.1,
237            hydrophobicity: -3.5,
238            charge: 0,
239            single_letter: 'Q',
240            three_letter: "Gln".to_string(),
241        },
242    );
243    map.insert(
244        'G',
245        AminoAcidInfo {
246            name: "Glycine".to_string(),
247            molecular_weight: 75.1,
248            hydrophobicity: -0.4,
249            charge: 0,
250            single_letter: 'G',
251            three_letter: "Gly".to_string(),
252        },
253    );
254    map.insert(
255        'H',
256        AminoAcidInfo {
257            name: "Histidine".to_string(),
258            molecular_weight: 155.2,
259            hydrophobicity: -3.2,
260            charge: 0,
261            single_letter: 'H',
262            three_letter: "His".to_string(),
263        },
264    );
265    map.insert(
266        'I',
267        AminoAcidInfo {
268            name: "Isoleucine".to_string(),
269            molecular_weight: 131.2,
270            hydrophobicity: 4.5,
271            charge: 0,
272            single_letter: 'I',
273            three_letter: "Ile".to_string(),
274        },
275    );
276    map.insert(
277        'L',
278        AminoAcidInfo {
279            name: "Leucine".to_string(),
280            molecular_weight: 131.2,
281            hydrophobicity: 3.8,
282            charge: 0,
283            single_letter: 'L',
284            three_letter: "Leu".to_string(),
285        },
286    );
287    map.insert(
288        'K',
289        AminoAcidInfo {
290            name: "Lysine".to_string(),
291            molecular_weight: 146.2,
292            hydrophobicity: -3.9,
293            charge: 1,
294            single_letter: 'K',
295            three_letter: "Lys".to_string(),
296        },
297    );
298    map.insert(
299        'M',
300        AminoAcidInfo {
301            name: "Methionine".to_string(),
302            molecular_weight: 149.2,
303            hydrophobicity: 1.9,
304            charge: 0,
305            single_letter: 'M',
306            three_letter: "Met".to_string(),
307        },
308    );
309    map.insert(
310        'F',
311        AminoAcidInfo {
312            name: "Phenylalanine".to_string(),
313            molecular_weight: 165.2,
314            hydrophobicity: 2.8,
315            charge: 0,
316            single_letter: 'F',
317            three_letter: "Phe".to_string(),
318        },
319    );
320    map.insert(
321        'P',
322        AminoAcidInfo {
323            name: "Proline".to_string(),
324            molecular_weight: 115.1,
325            hydrophobicity: -1.6,
326            charge: 0,
327            single_letter: 'P',
328            three_letter: "Pro".to_string(),
329        },
330    );
331    map.insert(
332        'S',
333        AminoAcidInfo {
334            name: "Serine".to_string(),
335            molecular_weight: 105.1,
336            hydrophobicity: -0.8,
337            charge: 0,
338            single_letter: 'S',
339            three_letter: "Ser".to_string(),
340        },
341    );
342    map.insert(
343        'T',
344        AminoAcidInfo {
345            name: "Threonine".to_string(),
346            molecular_weight: 119.1,
347            hydrophobicity: -0.7,
348            charge: 0,
349            single_letter: 'T',
350            three_letter: "Thr".to_string(),
351        },
352    );
353    map.insert(
354        'W',
355        AminoAcidInfo {
356            name: "Tryptophan".to_string(),
357            molecular_weight: 204.2,
358            hydrophobicity: -0.9,
359            charge: 0,
360            single_letter: 'W',
361            three_letter: "Trp".to_string(),
362        },
363    );
364    map.insert(
365        'Y',
366        AminoAcidInfo {
367            name: "Tyrosine".to_string(),
368            molecular_weight: 181.2,
369            hydrophobicity: -1.3,
370            charge: 0,
371            single_letter: 'Y',
372            three_letter: "Tyr".to_string(),
373        },
374    );
375    map.insert(
376        'V',
377        AminoAcidInfo {
378            name: "Valine".to_string(),
379            molecular_weight: 117.1,
380            hydrophobicity: 4.2,
381            charge: 0,
382            single_letter: 'V',
383            three_letter: "Val".to_string(),
384        },
385    );
386
387    // Ambiguous amino acids
388    map.insert(
389        'X',
390        AminoAcidInfo {
391            name: "Unknown".to_string(),
392            molecular_weight: 0.0,
393            hydrophobicity: 0.0,
394            charge: 0,
395            single_letter: 'X',
396            three_letter: "Xaa".to_string(),
397        },
398    );
399
400    map
401});
402
403// Static data for nucleotides
404static NUCLEOTIDES: Lazy<HashMap<char, NucleotideInfo>> = Lazy::new(|| {
405    let mut map = HashMap::new();
406
407    map.insert(
408        'A',
409        NucleotideInfo {
410            name: "Adenine".to_string(),
411            complement: 'T',
412            is_purine: true,
413            molecular_weight: 331.2,
414        },
415    );
416    map.insert(
417        'T',
418        NucleotideInfo {
419            name: "Thymine".to_string(),
420            complement: 'A',
421            is_purine: false,
422            molecular_weight: 322.2,
423        },
424    );
425    map.insert(
426        'G',
427        NucleotideInfo {
428            name: "Guanine".to_string(),
429            complement: 'C',
430            is_purine: true,
431            molecular_weight: 347.2,
432        },
433    );
434    map.insert(
435        'C',
436        NucleotideInfo {
437            name: "Cytosine".to_string(),
438            complement: 'G',
439            is_purine: false,
440            molecular_weight: 307.2,
441        },
442    );
443    map.insert(
444        'U',
445        NucleotideInfo {
446            name: "Uracil".to_string(),
447            complement: 'A',
448            is_purine: false,
449            molecular_weight: 308.2,
450        },
451    );
452
453    // Ambiguous nucleotides
454    map.insert(
455        'N',
456        NucleotideInfo {
457            name: "Any nucleotide".to_string(),
458            complement: 'N',
459            is_purine: false,
460            molecular_weight: 0.0,
461        },
462    );
463    map.insert(
464        'R',
465        NucleotideInfo {
466            name: "Purine".to_string(),
467            complement: 'Y',
468            is_purine: true,
469            molecular_weight: 0.0,
470        },
471    );
472    map.insert(
473        'Y',
474        NucleotideInfo {
475            name: "Pyrimidine".to_string(),
476            complement: 'R',
477            is_purine: false,
478            molecular_weight: 0.0,
479        },
480    );
481
482    map
483});
484
/// Standard genetic code: DNA codon -> one-letter amino acid, with `*` for
/// the three stop codons (`TAA`, `TAG`, `TGA`). `ATG` (Met) is the start codon.
static GENETIC_CODE: Lazy<HashMap<String, char>> = Lazy::new(|| {
    // Base order shared by all three codon positions.
    const BASES: [char; 4] = ['T', 'C', 'A', 'G'];
    // One residue per codon, enumerated with the first base varying slowest
    // and the third base varying fastest (TTT, TTC, TTA, TTG, TCT, ...).
    const RESIDUES: &str = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG";

    let residues = RESIDUES.as_bytes();
    let mut table = HashMap::new();

    for (i, &first) in BASES.iter().enumerate() {
        for (j, &second) in BASES.iter().enumerate() {
            for (k, &third) in BASES.iter().enumerate() {
                let codon: String = [first, second, third].iter().collect();
                let residue = char::from(residues[16 * i + 4 * j + k]);
                table.insert(codon, residue);
            }
        }
    }

    table
});
560
561impl Default for BioTokenizer {
562    fn default() -> Self {
563        Self::new()
564    }
565}
566
567impl BioTokenizer {
568    /// Create a new biological tokenizer
569    pub fn new() -> Self {
570        Self::with_config(BioTokenizerConfig::default())
571    }
572
573    /// Create with custom configuration
574    pub fn with_config(config: BioTokenizerConfig) -> Self {
575        let mut tokenizer = Self {
576            config,
577            vocab: HashMap::new(),
578            id_to_token: HashMap::new(),
579            next_id: 0,
580            amino_acids: AMINO_ACIDS.clone(),
581            nucleotides: NUCLEOTIDES.clone(),
582            genetic_code: GENETIC_CODE.clone(),
583            structure_patterns: Self::create_structure_patterns(),
584        };
585
586        tokenizer.initialize_vocab();
587        tokenizer
588    }
589
590    /// Initialize vocabulary with biological tokens
591    fn initialize_vocab(&mut self) {
592        // Add special tokens
593        if self.config.include_special_tokens {
594            self.add_token("[CLS]");
595            self.add_token("[SEP]");
596            self.add_token("[PAD]");
597            self.add_token("[UNK]");
598            self.add_token("[MASK]");
599            self.add_token("[START_SEQ]");
600            self.add_token("[END_SEQ]");
601            self.add_token("[START_PROTEIN]");
602            self.add_token("[END_PROTEIN]");
603            self.add_token("[START_DNA]");
604            self.add_token("[END_DNA]");
605            self.add_token("[START_RNA]");
606            self.add_token("[END_RNA]");
607        }
608
609        // Add nucleotides
610        if self.config.tokenize_dna || self.config.tokenize_rna {
611            let nucleotides: Vec<String> = self.nucleotides.keys().map(|c| c.to_string()).collect();
612            for nucleotide in nucleotides {
613                self.add_token(&nucleotide);
614            }
615        }
616
617        // Add amino acids
618        if self.config.tokenize_proteins {
619            let amino_acids: Vec<String> = self.amino_acids.keys().map(|c| c.to_string()).collect();
620            for amino_acid in amino_acids {
621                self.add_token(&amino_acid);
622            }
623        }
624
625        // Add k-mers if specified
626        if let Some(k) = self.config.kmer_size {
627            self.generate_kmers(k);
628        }
629
630        // Add stop/start codons
631        self.add_token("ATG"); // Start codon
632        self.add_token("TAA"); // Stop codon
633        self.add_token("TAG"); // Stop codon
634        self.add_token("TGA"); // Stop codon
635
636        // Add secondary structure elements if enabled
637        if self.config.tokenize_structure {
638            self.add_token("H"); // Helix
639            self.add_token("E"); // Beta sheet
640            self.add_token("C"); // Coil/loop
641            self.add_token("T"); // Turn
642        }
643    }
644
645    /// Generate k-mer vocabulary
646    fn generate_kmers(&mut self, k: usize) {
647        if self.config.tokenize_dna || self.config.tokenize_rna {
648            let nucleotides = if self.config.tokenize_rna { "AUGC" } else { "ATGC" };
649            self.generate_kmer_combinations(nucleotides.chars().collect(), k, String::new());
650        }
651
652        if self.config.tokenize_proteins {
653            let amino_acids: Vec<char> = self.amino_acids.keys().copied().collect();
654            if k <= 3 {
655                // Only generate small protein k-mers to avoid explosion
656                self.generate_kmer_combinations(amino_acids, k, String::new());
657            }
658        }
659    }
660
661    /// Recursively generate k-mer combinations
662    fn generate_kmer_combinations(&mut self, alphabet: Vec<char>, k: usize, current: String) {
663        if current.len() == k {
664            self.add_token(&current);
665            return;
666        }
667
668        if self.vocab.len() >= self.config.vocab_size.unwrap_or(5000) {
669            return; // Stop if vocabulary size limit reached
670        }
671
672        for &c in &alphabet {
673            let mut next = current.clone();
674            next.push(c);
675            self.generate_kmer_combinations(alphabet.clone(), k, next);
676        }
677    }
678
679    /// Add token to vocabulary
680    fn add_token(&mut self, token: &str) -> u32 {
681        if let Some(&id) = self.vocab.get(token) {
682            return id;
683        }
684
685        let id = self.next_id;
686        self.vocab.insert(token.to_string(), id);
687        self.id_to_token.insert(id, token.to_string());
688        self.next_id += 1;
689        id
690    }
691
692    /// Create secondary structure patterns
693    fn create_structure_patterns() -> Vec<Regex> {
694        vec![
695            Regex::new(r"[HEC]+").expect("valid regex"), // Secondary structure annotations
696            Regex::new(r"[αβ]+").expect("valid regex"),  // Greek letter annotations
697        ]
698    }
699
700    /// Tokenize biological sequence
701    pub fn tokenize_bio(&self, sequence: &str) -> Result<Vec<BioToken>> {
702        let sequence = if self.config.preserve_case {
703            sequence.to_string()
704        } else {
705            sequence.to_uppercase()
706        };
707
708        // Detect sequence type
709        let seq_type = self.detect_sequence_type(&sequence);
710
711        match seq_type {
712            SequenceType::DNA => self.tokenize_dna(&sequence),
713            SequenceType::RNA => self.tokenize_rna(&sequence),
714            SequenceType::Protein => self.tokenize_protein(&sequence),
715            SequenceType::Structure => self.tokenize_structure(&sequence),
716            SequenceType::Unknown => self.tokenize_fallback(&sequence),
717        }
718    }
719
720    /// Detect sequence type
721    fn detect_sequence_type(&self, sequence: &str) -> SequenceType {
722        let chars: Vec<char> = sequence.chars().collect();
723        let total = chars.len() as f64;
724
725        if total == 0.0 {
726            return SequenceType::Unknown;
727        }
728
729        // Count different character types
730        let dna_chars = chars.iter().filter(|&&c| "ATGC".contains(c)).count() as f64 / total;
731        let rna_chars = chars.iter().filter(|&&c| "AUGC".contains(c)).count() as f64 / total;
732        let protein_chars =
733            chars.iter().filter(|&&c| self.amino_acids.contains_key(&c)).count() as f64 / total;
734        let structure_chars = chars.iter().filter(|&&c| "HEC".contains(c)).count() as f64 / total;
735
736        // Determine most likely sequence type
737        if dna_chars > 0.8 && !sequence.contains('U') {
738            SequenceType::DNA
739        } else if rna_chars > 0.8 && sequence.contains('U') {
740            SequenceType::RNA
741        } else if protein_chars > 0.8 {
742            SequenceType::Protein
743        } else if structure_chars > 0.5 {
744            SequenceType::Structure
745        } else {
746            SequenceType::Unknown
747        }
748    }
749
750    /// Tokenize DNA sequence
751    fn tokenize_dna(&self, sequence: &str) -> Result<Vec<BioToken>> {
752        let mut tokens = Vec::new();
753
754        if let Some(k) = self.config.kmer_size {
755            // K-mer tokenization
756            tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::DNANucleotide)?);
757        } else {
758            // Single nucleotide tokenization
759            for (i, c) in sequence.char_indices() {
760                let token_type = if self.nucleotides.contains_key(&c) {
761                    if "ATGC".contains(c) {
762                        BioTokenType::DNANucleotide
763                    } else {
764                        BioTokenType::AmbiguousNucleotide
765                    }
766                } else {
767                    BioTokenType::Unknown
768                };
769
770                let metadata = self.create_nucleotide_metadata(c);
771
772                tokens.push(BioToken {
773                    text: c.to_string(),
774                    token_type,
775                    start: i,
776                    end: i + 1,
777                    metadata,
778                });
779            }
780        }
781
782        Ok(tokens)
783    }
784
785    /// Tokenize RNA sequence
786    fn tokenize_rna(&self, sequence: &str) -> Result<Vec<BioToken>> {
787        let mut tokens = Vec::new();
788
789        if let Some(k) = self.config.kmer_size {
790            tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::RNANucleotide)?);
791        } else {
792            for (i, c) in sequence.char_indices() {
793                let token_type = if self.nucleotides.contains_key(&c) {
794                    if "AUGC".contains(c) {
795                        BioTokenType::RNANucleotide
796                    } else {
797                        BioTokenType::AmbiguousNucleotide
798                    }
799                } else {
800                    BioTokenType::Unknown
801                };
802
803                let metadata = self.create_nucleotide_metadata(c);
804
805                tokens.push(BioToken {
806                    text: c.to_string(),
807                    token_type,
808                    start: i,
809                    end: i + 1,
810                    metadata,
811                });
812            }
813        }
814
815        Ok(tokens)
816    }
817
818    /// Tokenize protein sequence
819    fn tokenize_protein(&self, sequence: &str) -> Result<Vec<BioToken>> {
820        let mut tokens = Vec::new();
821
822        if let Some(k) = self.config.kmer_size {
823            tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::AminoAcid)?);
824        } else {
825            for (i, c) in sequence.char_indices() {
826                let token_type = if self.amino_acids.contains_key(&c) {
827                    if "ACDEFGHIKLMNPQRSTVWY".contains(c) {
828                        BioTokenType::AminoAcid
829                    } else {
830                        BioTokenType::AmbiguousAminoAcid
831                    }
832                } else {
833                    BioTokenType::Unknown
834                };
835
836                let metadata = self.create_amino_acid_metadata(c);
837
838                tokens.push(BioToken {
839                    text: c.to_string(),
840                    token_type,
841                    start: i,
842                    end: i + 1,
843                    metadata,
844                });
845            }
846        }
847
848        Ok(tokens)
849    }
850
851    /// Tokenize secondary structure
852    fn tokenize_structure(&self, sequence: &str) -> Result<Vec<BioToken>> {
853        let mut tokens = Vec::new();
854
855        for (i, c) in sequence.char_indices() {
856            let token_type = BioTokenType::SecondaryStructure;
857            let structure_type = match c {
858                'H' => Some("Helix".to_string()),
859                'E' => Some("Beta sheet".to_string()),
860                'C' => Some("Coil".to_string()),
861                'T' => Some("Turn".to_string()),
862                _ => Some("Unknown".to_string()),
863            };
864
865            let metadata = BioTokenMetadata {
866                structure_type,
867                ..Default::default()
868            };
869
870            tokens.push(BioToken {
871                text: c.to_string(),
872                token_type,
873                start: i,
874                end: i + 1,
875                metadata: Some(metadata),
876            });
877        }
878
879        Ok(tokens)
880    }
881
882    /// Tokenize using k-mers
883    fn tokenize_kmers(
884        &self,
885        sequence: &str,
886        k: usize,
887        _base_type: BioTokenType,
888    ) -> Result<Vec<BioToken>> {
889        let mut tokens = Vec::new();
890        let chars: Vec<char> = sequence.chars().collect();
891
892        if chars.len() < k {
893            return Ok(tokens);
894        }
895
896        let step = if self.config.overlapping_kmers { 1 } else { k };
897
898        for i in (0..=chars.len() - k).step_by(step) {
899            let kmer: String = chars[i..i + k].iter().collect();
900
901            let token_type = if kmer.len() == 3 && self.genetic_code.contains_key(&kmer) {
902                if self.genetic_code[&kmer] == '*' {
903                    BioTokenType::StopCodon
904                } else if kmer == "ATG" {
905                    BioTokenType::StartCodon
906                } else {
907                    BioTokenType::Kmer
908                }
909            } else {
910                BioTokenType::Kmer
911            };
912
913            let metadata = self.create_kmer_metadata(&kmer, i);
914
915            tokens.push(BioToken {
916                text: kmer,
917                token_type,
918                start: i,
919                end: i + k,
920                metadata,
921            });
922        }
923
924        Ok(tokens)
925    }
926
927    /// Fallback tokenization
928    fn tokenize_fallback(&self, sequence: &str) -> Result<Vec<BioToken>> {
929        let mut tokens = Vec::new();
930
931        for (i, c) in sequence.char_indices() {
932            tokens.push(BioToken {
933                text: c.to_string(),
934                token_type: BioTokenType::Unknown,
935                start: i,
936                end: i + 1,
937                metadata: None,
938            });
939        }
940
941        Ok(tokens)
942    }
943
944    /// Create nucleotide metadata
945    fn create_nucleotide_metadata(&self, nucleotide: char) -> Option<BioTokenMetadata> {
946        self.nucleotides.get(&nucleotide).map(|info| BioTokenMetadata {
947            molecular_weight: Some(info.molecular_weight),
948            hydrophobicity: None,
949            charge: None,
950            gc_content: if "GC".contains(nucleotide) { Some(1.0) } else { Some(0.0) },
951            melting_temp: None,
952            codon_position: None,
953            reading_frame: None,
954            structure_type: None,
955        })
956    }
957
958    /// Create amino acid metadata
959    fn create_amino_acid_metadata(&self, amino_acid: char) -> Option<BioTokenMetadata> {
960        self.amino_acids.get(&amino_acid).map(|info| BioTokenMetadata {
961            molecular_weight: Some(info.molecular_weight),
962            hydrophobicity: Some(info.hydrophobicity),
963            charge: Some(info.charge),
964            gc_content: None,
965            melting_temp: None,
966            codon_position: None,
967            reading_frame: None,
968            structure_type: None,
969        })
970    }
971
972    /// Create k-mer metadata
973    fn create_kmer_metadata(&self, kmer: &str, position: usize) -> Option<BioTokenMetadata> {
974        let mut metadata = BioTokenMetadata::default();
975
976        // Calculate GC content for DNA/RNA k-mers
977        if kmer.chars().all(|c| "ATGCU".contains(c)) {
978            let gc_count = kmer.chars().filter(|&c| "GC".contains(c)).count();
979            metadata.gc_content = Some(gc_count as f64 / kmer.len() as f64);
980        }
981
982        // Set reading frame for codons
983        if kmer.len() == 3 {
984            metadata.reading_frame = Some((position % 3) as u8);
985            if let Some(&amino_acid) = self.genetic_code.get(kmer) {
986                if let Some(aa_info) = self.amino_acids.get(&amino_acid) {
987                    metadata.molecular_weight = Some(aa_info.molecular_weight);
988                    metadata.hydrophobicity = Some(aa_info.hydrophobicity);
989                    metadata.charge = Some(aa_info.charge);
990                }
991            }
992        }
993
994        Some(metadata)
995    }
996
    /// Borrow the tokenizer's vocabulary.
    ///
    /// Returns a reference to the internal token-text → id map without
    /// copying it. The `Tokenizer` trait method of the same name returns an
    /// owned clone instead.
    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.vocab
    }
1001
    /// Look up the token text for a vocabulary id.
    ///
    /// Returns a borrowed token string, or `None` when `id` has no entry in
    /// the id → token map.
    pub fn id_to_token(&self, id: u32) -> Option<&String> {
        self.id_to_token.get(&id)
    }
1006
    /// Borrow the configuration this tokenizer was built with.
    pub fn config(&self) -> &BioTokenizerConfig {
        &self.config
    }
1011
1012    /// Translate DNA to protein
1013    pub fn translate_dna(&self, dna_sequence: &str) -> Result<String> {
1014        let dna = dna_sequence.to_uppercase();
1015        let mut protein = String::new();
1016
1017        for i in (0..dna.len()).step_by(3) {
1018            if i + 3 <= dna.len() {
1019                let codon = &dna[i..i + 3];
1020                if let Some(&amino_acid) = self.genetic_code.get(codon) {
1021                    if amino_acid == '*' {
1022                        break; // Stop at stop codon
1023                    }
1024                    protein.push(amino_acid);
1025                } else {
1026                    protein.push('X'); // Unknown amino acid
1027                }
1028            }
1029        }
1030
1031        Ok(protein)
1032    }
1033
1034    /// Get reverse complement of DNA sequence
1035    pub fn reverse_complement(&self, dna_sequence: &str) -> String {
1036        dna_sequence
1037            .chars()
1038            .rev()
1039            .map(|c| {
1040                if let Some(info) = self.nucleotides.get(&c.to_ascii_uppercase()) {
1041                    info.complement
1042                } else {
1043                    'N'
1044                }
1045            })
1046            .collect()
1047    }
1048}
1049
/// Sequence type detection
///
/// Coarse classification of an input sequence, as produced by
/// `detect_sequence_type` (see the unit tests in this file).
#[derive(Debug, Clone, PartialEq)]
// `DNA` and `RNA` keep their conventional all-caps spelling, hence the lint
// exemption.
#[allow(clippy::upper_case_acronyms)]
enum SequenceType {
    /// DNA sequence, e.g. `ATGCGATCG`.
    DNA,
    /// RNA sequence, e.g. `AUGCGAUCG`.
    RNA,
    /// Protein (amino-acid) sequence, e.g. `MTKQVFTPG`.
    Protein,
    /// NOTE(review): presumably secondary-structure annotation input (cf. the
    /// `tokenize_structure` config flag) — confirm against the detector.
    Structure,
    /// NOTE(review): presumably the result when no other category applies —
    /// confirm against the detector.
    Unknown,
}
1060
1061impl Tokenizer for BioTokenizer {
1062    fn encode(&self, text: &str) -> Result<TokenizedInput> {
1063        let bio_tokens = self.tokenize_bio(text)?;
1064        let mut input_ids = Vec::new();
1065
1066        for token in bio_tokens {
1067            if let Some(&id) = self.vocab.get(&token.text) {
1068                input_ids.push(id);
1069            } else {
1070                // Use UNK token
1071                if let Some(&unk_id) = self.vocab.get("[UNK]") {
1072                    input_ids.push(unk_id);
1073                } else {
1074                    input_ids.push(0); // Fallback
1075                }
1076            }
1077        }
1078
1079        // Apply max length constraint
1080        if let Some(max_len) = self.config.max_length {
1081            input_ids.truncate(max_len);
1082        }
1083
1084        let input_len = input_ids.len();
1085        Ok(TokenizedInput {
1086            input_ids,
1087            attention_mask: vec![1; input_len],
1088            token_type_ids: None,
1089            special_tokens_mask: None,
1090            offset_mapping: None,
1091            overflowing_tokens: None,
1092        })
1093    }
1094
1095    fn decode(&self, token_ids: &[u32]) -> Result<String> {
1096        let mut result = String::new();
1097
1098        for &id in token_ids {
1099            if let Some(token) = self.id_to_token.get(&id) {
1100                if !token.starts_with('[') || !token.ends_with(']') {
1101                    result.push_str(token);
1102                }
1103            }
1104        }
1105
1106        Ok(result)
1107    }
1108
1109    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
1110        let mut tokenized_a = self.encode(text_a)?;
1111        let tokenized_b = self.encode(text_b)?;
1112
1113        // Add separator token if available
1114        if let Some(&sep_id) = self.vocab.get("[SEP]") {
1115            tokenized_a.input_ids.push(sep_id);
1116        }
1117
1118        tokenized_a.input_ids.extend(tokenized_b.input_ids);
1119
1120        // Apply max length constraint
1121        if let Some(max_len) = self.config.max_length {
1122            tokenized_a.input_ids.truncate(max_len);
1123        }
1124
1125        Ok(tokenized_a)
1126    }
1127
1128    fn vocab_size(&self) -> usize {
1129        self.vocab.len()
1130    }
1131
1132    fn get_vocab(&self) -> HashMap<String, u32> {
1133        self.vocab.clone()
1134    }
1135
1136    fn token_to_id(&self, token: &str) -> Option<u32> {
1137        self.vocab.get(token).copied()
1138    }
1139
1140    fn id_to_token(&self, id: u32) -> Option<String> {
1141        self.id_to_token.get(&id).cloned()
1142    }
1143}
1144
/// Biological sequence analysis
///
/// Summary statistics produced by `BioTokenizer::analyze` from the tokens of
/// one input sequence. The protein-related `Option` fields are `None` when no
/// amino-acid residue was counted; `gc_content` is `None` when no nucleotide
/// was counted.
pub struct BioAnalysis {
    /// Token type distribution: number of tokens seen per token type.
    pub token_types: HashMap<BioTokenType, usize>,
    /// Amino acid composition (for proteins): occurrences per residue letter.
    pub amino_acid_composition: HashMap<char, usize>,
    /// Nucleotide composition (for DNA/RNA): occurrences per base letter.
    pub nucleotide_composition: HashMap<char, usize>,
    /// GC content (for DNA/RNA): fraction of counted nucleotides that are
    /// `G` or `C`.
    pub gc_content: Option<f64>,
    /// Molecular weight (for proteins): sum of the per-residue weights taken
    /// from the amino-acid table.
    pub molecular_weight: Option<f64>,
    /// Hydrophobicity (for proteins): mean of the per-residue hydrophobicity
    /// values.
    pub avg_hydrophobicity: Option<f64>,
    /// Charge (for proteins): sum of the per-residue charges.
    pub net_charge: Option<i32>,
    /// K-mer diversity: Simpson-style score (`1 - Σ f²`) computed by
    /// `analyze`; `0.0` when the sequence produced no tokens.
    pub kmer_diversity: f64,
    /// Sequence length: byte length of the analyzed input string.
    pub sequence_length: usize,
}
1166
1167impl BioTokenizer {
1168    /// Analyze biological sequence
1169    pub fn analyze(&self, sequence: &str) -> Result<BioAnalysis> {
1170        let tokens = self.tokenize_bio(sequence)?;
1171
1172        let mut token_types = HashMap::new();
1173        let mut amino_acid_composition = HashMap::new();
1174        let mut nucleotide_composition = HashMap::new();
1175        let mut molecular_weight = 0.0;
1176        let mut total_hydrophobicity = 0.0;
1177        let mut net_charge = 0i32;
1178        let mut gc_count = 0;
1179        let mut nucleotide_count = 0;
1180        let mut protein_residue_count = 0;
1181
1182        for token in &tokens {
1183            *token_types.entry(token.token_type.clone()).or_insert(0) += 1;
1184
1185            if token.text.len() == 1 {
1186                let c = token
1187                    .text
1188                    .chars()
1189                    .next()
1190                    .expect("token.text with len()==1 must have at least one char");
1191
1192                match token.token_type {
1193                    BioTokenType::AminoAcid => {
1194                        *amino_acid_composition.entry(c).or_insert(0) += 1;
1195                        if let Some(info) = self.amino_acids.get(&c) {
1196                            molecular_weight += info.molecular_weight;
1197                            total_hydrophobicity += info.hydrophobicity;
1198                            net_charge += info.charge as i32;
1199                            protein_residue_count += 1;
1200                        }
1201                    },
1202                    BioTokenType::DNANucleotide | BioTokenType::RNANucleotide => {
1203                        *nucleotide_composition.entry(c).or_insert(0) += 1;
1204                        if "GC".contains(c) {
1205                            gc_count += 1;
1206                        }
1207                        nucleotide_count += 1;
1208                    },
1209                    _ => {},
1210                }
1211            } else {
1212                // Handle k-mer tokens for both nucleotide and protein analysis
1213                if token.token_type == BioTokenType::Kmer {
1214                    for c in token.text.chars() {
1215                        // Check if it's a nucleotide sequence
1216                        if "ATGCU".contains(c) {
1217                            *nucleotide_composition.entry(c).or_insert(0) += 1;
1218                            if "GC".contains(c) {
1219                                gc_count += 1;
1220                            }
1221                            nucleotide_count += 1;
1222                        }
1223                        // Check if it's a protein sequence
1224                        else if self.amino_acids.contains_key(&c) {
1225                            *amino_acid_composition.entry(c).or_insert(0) += 1;
1226                            if let Some(info) = self.amino_acids.get(&c) {
1227                                molecular_weight += info.molecular_weight;
1228                                total_hydrophobicity += info.hydrophobicity;
1229                                net_charge += info.charge as i32;
1230                                protein_residue_count += 1;
1231                            }
1232                        }
1233                    }
1234                }
1235            }
1236        }
1237
1238        let gc_content = if nucleotide_count > 0 {
1239            Some(gc_count as f64 / nucleotide_count as f64)
1240        } else {
1241            None
1242        };
1243
1244        let avg_hydrophobicity = if protein_residue_count > 0 {
1245            Some(total_hydrophobicity / protein_residue_count as f64)
1246        } else {
1247            None
1248        };
1249
1250        let molecular_weight_final =
1251            if protein_residue_count > 0 { Some(molecular_weight) } else { None };
1252
1253        let net_charge_final = if protein_residue_count > 0 { Some(net_charge) } else { None };
1254
1255        // Calculate k-mer diversity (Simpson's diversity index)
1256        let total_tokens = tokens.len();
1257        let kmer_diversity = if total_tokens > 0 {
1258            let mut diversity = 0.0;
1259            for count in token_types.values() {
1260                let frequency = *count as f64 / total_tokens as f64;
1261                diversity += frequency * frequency;
1262            }
1263            1.0 - diversity
1264        } else {
1265            0.0
1266        };
1267
1268        Ok(BioAnalysis {
1269            token_types,
1270            amino_acid_composition,
1271            nucleotide_composition,
1272            gc_content,
1273            molecular_weight: molecular_weight_final,
1274            avg_hydrophobicity,
1275            net_charge: net_charge_final,
1276            kmer_diversity,
1277            sequence_length: sequence.len(),
1278        })
1279    }
1280}
1281
#[cfg(test)]
mod tests {
    use super::*;

    /// A default tokenizer ships with a non-empty vocabulary that contains
    /// the basic nucleotide letters.
    #[test]
    fn test_bio_tokenizer_creation() {
        let tokenizer = BioTokenizer::new();
        let vocab = tokenizer.get_vocab();
        assert!(!vocab.is_empty());
        assert!(vocab.contains_key("A"));
        assert!(vocab.contains_key("T"));
    }

    /// DNA, RNA and protein inputs are classified as such.
    #[test]
    fn test_sequence_type_detection() {
        let tokenizer = BioTokenizer::new();
        assert_eq!(
            tokenizer.detect_sequence_type("ATGCGATCG"),
            SequenceType::DNA
        );
        assert_eq!(
            tokenizer.detect_sequence_type("AUGCGAUCG"),
            SequenceType::RNA
        );
        assert_eq!(
            tokenizer.detect_sequence_type("MTKQVFTPG"),
            SequenceType::Protein
        );
    }

    /// Encoding a DNA sequence succeeds and yields at least one id.
    #[test]
    fn test_dna_encoding() {
        let tokenizer = BioTokenizer::new();
        let tokenized = tokenizer.encode("ATGCGATCG").expect("Operation failed in test");
        assert!(!tokenized.input_ids.is_empty());
    }

    /// Encoding a protein sequence succeeds and yields at least one id.
    #[test]
    fn test_protein_encoding() {
        let tokenizer = BioTokenizer::new();
        let tokenized = tokenizer.encode("MTKQVFTPG").expect("Operation failed in test");
        assert!(!tokenized.input_ids.is_empty());
    }

    /// With a k-mer size of 3, at least one three-letter token is produced.
    #[test]
    fn test_kmer_tokenization() {
        let config = BioTokenizerConfig {
            kmer_size: Some(3),
            ..Default::default()
        };
        let tokenizer = BioTokenizer::with_config(config);

        let tokens = tokenizer.tokenize_bio("ATGCGATCG").expect("Operation failed in test");
        assert!(tokens.iter().any(|t| t.text.len() == 3));
    }

    /// Translation stops at the stop codon and omits it.
    #[test]
    fn test_translation() {
        let tokenizer = BioTokenizer::new();
        let protein = tokenizer.translate_dna("ATGAAATAG").expect("Operation failed in test");
        assert_eq!(protein, "MK"); // ATG=M, AAA=K, TAG=stop
    }

    /// `ATGC` reverse-complements to `GCAT`.
    #[test]
    fn test_reverse_complement() {
        let tokenizer = BioTokenizer::new();
        let rc = tokenizer.reverse_complement("ATGC");
        assert_eq!(rc, "GCAT");
    }

    /// Amino-acid metadata carries weight and hydrophobicity.
    #[test]
    fn test_amino_acid_metadata() {
        let tokenizer = BioTokenizer::new();
        let meta = tokenizer
            .create_amino_acid_metadata('A')
            .expect("Operation failed in test");
        assert!(meta.molecular_weight.is_some());
        assert!(meta.hydrophobicity.is_some());
    }

    /// Guanine reports a GC content of exactly 1.0.
    #[test]
    fn test_nucleotide_metadata() {
        let tokenizer = BioTokenizer::new();
        let meta = tokenizer
            .create_nucleotide_metadata('G')
            .expect("Operation failed in test");
        assert_eq!(meta.gc_content, Some(1.0));
    }

    /// Analyzing DNA reports GC content and a nucleotide composition.
    #[test]
    fn test_bio_analysis() {
        let tokenizer = BioTokenizer::new();
        let result = tokenizer.analyze("ATGCGATCG").expect("Operation failed in test");
        assert!(result.gc_content.is_some());
        assert!(!result.nucleotide_composition.is_empty());
    }

    /// Analyzing a protein reports weight, hydrophobicity and composition.
    #[test]
    fn test_protein_analysis() {
        let tokenizer = BioTokenizer::new();
        let result = tokenizer.analyze("MTKQVFTPG").expect("Operation failed in test");
        assert!(result.molecular_weight.is_some());
        assert!(result.avg_hydrophobicity.is_some());
        assert!(!result.amino_acid_composition.is_empty());
    }

    /// Start and stop codons receive their dedicated token types.
    #[test]
    fn test_stop_codon_detection() {
        let tokenizer = BioTokenizer::new();
        let tokens = tokenizer.tokenize_bio("ATGTAG").expect("Operation failed in test");
        assert!(tokens.iter().any(|t| t.token_type == BioTokenType::StartCodon));
        assert!(tokens.iter().any(|t| t.token_type == BioTokenType::StopCodon));
    }

    /// `max_length` caps the number of ids returned by `encode`.
    #[test]
    fn test_max_length_constraint() {
        let config = BioTokenizerConfig {
            max_length: Some(5),
            ..Default::default()
        };
        let tokenizer = BioTokenizer::with_config(config);

        let tokenized = tokenizer
            .encode("ATGCGATCGATCGATCG")
            .expect("Operation failed in test");
        assert!(tokenized.input_ids.len() <= 5);
    }
}