use std::collections::HashMap;
use std::sync::LazyLock;

use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserialize, Serialize};

use trustformers_core::errors::Result;
use trustformers_core::traits::{TokenizedInput, Tokenizer};
12
/// Configuration options controlling how a [`BioTokenizer`] splits and
/// normalizes biological sequences.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BioTokenizerConfig {
    /// Maximum number of token ids produced by `encode`; longer inputs are truncated.
    pub max_length: Option<usize>,
    /// Whether to seed the vocabulary with special tokens such as `[CLS]` and `[SEP]`.
    pub include_special_tokens: bool,
    /// Enable tokenization of DNA sequences.
    pub tokenize_dna: bool,
    /// Enable tokenization of RNA sequences.
    pub tokenize_rna: bool,
    /// Enable tokenization of protein (amino-acid) sequences.
    pub tokenize_proteins: bool,
    /// When `Some(k)`, sequences are split into k-mers of this size instead of
    /// single characters.
    pub kmer_size: Option<usize>,
    /// When true, k-mers are taken at every position (stride 1); otherwise stride `k`.
    pub overlapping_kmers: bool,
    /// When true, input case is preserved; otherwise input is upper-cased first.
    pub preserve_case: bool,
    /// Whether ambiguous IUPAC codes (e.g. N, R, Y) should be handled.
    /// NOTE(review): currently not consulted by any tokenization path — confirm intent.
    pub handle_ambiguous: bool,
    /// Enable secondary-structure token entries (H/E/C/T) in the vocabulary.
    pub tokenize_structure: bool,
    /// Soft cap on vocabulary size applied while generating k-mer tokens.
    pub vocab_size: Option<usize>,
}
39
40impl Default for BioTokenizerConfig {
41 fn default() -> Self {
42 Self {
43 max_length: Some(2048),
44 include_special_tokens: true,
45 tokenize_dna: true,
46 tokenize_rna: true,
47 tokenize_proteins: true,
48 kmer_size: Some(3), overlapping_kmers: true,
50 preserve_case: false,
51 handle_ambiguous: true,
52 tokenize_structure: false,
53 vocab_size: Some(5000),
54 }
55 }
56}
57
/// Classification assigned to each token produced by [`BioTokenizer`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BioTokenType {
    /// Canonical DNA base (A, T, G, C).
    DNANucleotide,
    /// Canonical RNA base (A, U, G, C).
    RNANucleotide,
    /// One of the 20 standard amino acids.
    AminoAcid,
    /// A k-mer (fixed-length subsequence) token.
    Kmer,
    /// IUPAC nucleotide ambiguity code (e.g. N, R, Y).
    AmbiguousNucleotide,
    /// Ambiguous amino-acid code (e.g. X).
    AmbiguousAminoAcid,
    /// Stop codon (TAA, TAG, or TGA).
    StopCodon,
    /// Start codon (ATG).
    StartCodon,
    /// Secondary-structure label (helix/sheet/coil/turn).
    SecondaryStructure,
    /// Chemical/post-translational modification.
    /// NOTE(review): never emitted by the current tokenization paths.
    Modification,
    /// Special token such as `[CLS]` or `[SEP]`.
    /// NOTE(review): never emitted by the current tokenization paths.
    Special,
    /// Anything not recognized by the other categories.
    Unknown,
}
86
/// A single token produced from a biological sequence, with its source span
/// and optional physico-chemical metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BioToken {
    /// Token text: a single character or a k-mer.
    pub text: String,
    /// Classification of this token.
    pub token_type: BioTokenType,
    /// Start offset in the source sequence (byte offset for per-character
    /// tokenization; character offset for k-mer tokenization).
    pub start: usize,
    /// End offset (exclusive), same unit as `start`.
    pub end: usize,
    /// Optional per-token metadata (weights, charge, GC content, ...).
    pub metadata: Option<BioTokenMetadata>,
}
101
/// Optional physico-chemical annotations attached to a token.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BioTokenMetadata {
    /// Molecular weight in daltons, when known.
    pub molecular_weight: Option<f64>,
    /// Hydrophobicity score (populated for amino-acid tokens only).
    pub hydrophobicity: Option<f64>,
    /// Net charge at physiological pH (populated for amino-acid tokens only).
    pub charge: Option<i8>,
    /// Fraction of G/C characters (populated for nucleotide tokens only).
    pub gc_content: Option<f64>,
    /// Melting temperature. NOTE(review): never populated by current code.
    pub melting_temp: Option<f64>,
    /// Codon position. NOTE(review): never populated by current code.
    pub codon_position: Option<u8>,
    /// Reading frame (0-2) derived from a 3-mer's position in the sequence.
    pub reading_frame: Option<u8>,
    /// Secondary-structure label ("Helix", "Beta sheet", "Coil", "Turn").
    pub structure_type: Option<String>,
}
122
/// Tokenizer for biological sequences (DNA, RNA, proteins, and secondary
/// structure). The vocabulary is built once at construction time from the
/// configuration.
pub struct BioTokenizer {
    /// Behavior switches; see [`BioTokenizerConfig`].
    config: BioTokenizerConfig,
    /// Token text -> id.
    vocab: HashMap<String, u32>,
    /// id -> token text (inverse of `vocab`).
    id_to_token: HashMap<u32, String>,
    /// Next id handed out by `add_token`.
    next_id: u32,
    /// Per-amino-acid reference data (cloned from the `AMINO_ACIDS` table).
    amino_acids: HashMap<char, AminoAcidInfo>,
    /// Per-nucleotide reference data (cloned from the `NUCLEOTIDES` table).
    nucleotides: HashMap<char, NucleotideInfo>,
    /// Codon -> amino-acid single-letter code ('*' for stop codons).
    genetic_code: HashMap<String, char>,
    /// Compiled secondary-structure regexes; currently unused.
    #[allow(dead_code)]
    structure_patterns: Vec<Regex>,
}
135
/// Reference data for one amino acid.
#[derive(Debug, Clone)]
struct AminoAcidInfo {
    /// Full name, e.g. "Alanine".
    #[allow(dead_code)]
    name: String,
    /// Molecular weight in daltons.
    molecular_weight: f64,
    /// Hydrophobicity score.
    hydrophobicity: f64,
    /// Net charge at physiological pH (-1, 0, or +1).
    charge: i8,
    /// Single-letter code (same as the map key in `AMINO_ACIDS`).
    #[allow(dead_code)]
    single_letter: char,
    /// Three-letter code, e.g. "Ala".
    #[allow(dead_code)]
    three_letter: String,
}
149
/// Reference data for one nucleotide (canonical bases plus IUPAC ambiguity codes).
#[derive(Debug, Clone)]
struct NucleotideInfo {
    /// Full name, e.g. "Adenine".
    #[allow(dead_code)]
    name: String,
    /// Complement base used by `reverse_complement`.
    complement: char,
    /// True for purines (A, G).
    #[allow(dead_code)]
    is_purine: bool,
    /// Molecular weight in daltons (0.0 for ambiguity codes).
    molecular_weight: f64,
}
160
161static AMINO_ACIDS: Lazy<HashMap<char, AminoAcidInfo>> = Lazy::new(|| {
163 let mut map = HashMap::new();
164
165 map.insert(
167 'A',
168 AminoAcidInfo {
169 name: "Alanine".to_string(),
170 molecular_weight: 89.1,
171 hydrophobicity: 1.8,
172 charge: 0,
173 single_letter: 'A',
174 three_letter: "Ala".to_string(),
175 },
176 );
177 map.insert(
178 'R',
179 AminoAcidInfo {
180 name: "Arginine".to_string(),
181 molecular_weight: 174.2,
182 hydrophobicity: -4.5,
183 charge: 1,
184 single_letter: 'R',
185 three_letter: "Arg".to_string(),
186 },
187 );
188 map.insert(
189 'N',
190 AminoAcidInfo {
191 name: "Asparagine".to_string(),
192 molecular_weight: 132.1,
193 hydrophobicity: -3.5,
194 charge: 0,
195 single_letter: 'N',
196 three_letter: "Asn".to_string(),
197 },
198 );
199 map.insert(
200 'D',
201 AminoAcidInfo {
202 name: "Aspartic acid".to_string(),
203 molecular_weight: 133.1,
204 hydrophobicity: -3.5,
205 charge: -1,
206 single_letter: 'D',
207 three_letter: "Asp".to_string(),
208 },
209 );
210 map.insert(
211 'C',
212 AminoAcidInfo {
213 name: "Cysteine".to_string(),
214 molecular_weight: 121.0,
215 hydrophobicity: 2.5,
216 charge: 0,
217 single_letter: 'C',
218 three_letter: "Cys".to_string(),
219 },
220 );
221 map.insert(
222 'E',
223 AminoAcidInfo {
224 name: "Glutamic acid".to_string(),
225 molecular_weight: 147.1,
226 hydrophobicity: -3.5,
227 charge: -1,
228 single_letter: 'E',
229 three_letter: "Glu".to_string(),
230 },
231 );
232 map.insert(
233 'Q',
234 AminoAcidInfo {
235 name: "Glutamine".to_string(),
236 molecular_weight: 146.1,
237 hydrophobicity: -3.5,
238 charge: 0,
239 single_letter: 'Q',
240 three_letter: "Gln".to_string(),
241 },
242 );
243 map.insert(
244 'G',
245 AminoAcidInfo {
246 name: "Glycine".to_string(),
247 molecular_weight: 75.1,
248 hydrophobicity: -0.4,
249 charge: 0,
250 single_letter: 'G',
251 three_letter: "Gly".to_string(),
252 },
253 );
254 map.insert(
255 'H',
256 AminoAcidInfo {
257 name: "Histidine".to_string(),
258 molecular_weight: 155.2,
259 hydrophobicity: -3.2,
260 charge: 0,
261 single_letter: 'H',
262 three_letter: "His".to_string(),
263 },
264 );
265 map.insert(
266 'I',
267 AminoAcidInfo {
268 name: "Isoleucine".to_string(),
269 molecular_weight: 131.2,
270 hydrophobicity: 4.5,
271 charge: 0,
272 single_letter: 'I',
273 three_letter: "Ile".to_string(),
274 },
275 );
276 map.insert(
277 'L',
278 AminoAcidInfo {
279 name: "Leucine".to_string(),
280 molecular_weight: 131.2,
281 hydrophobicity: 3.8,
282 charge: 0,
283 single_letter: 'L',
284 three_letter: "Leu".to_string(),
285 },
286 );
287 map.insert(
288 'K',
289 AminoAcidInfo {
290 name: "Lysine".to_string(),
291 molecular_weight: 146.2,
292 hydrophobicity: -3.9,
293 charge: 1,
294 single_letter: 'K',
295 three_letter: "Lys".to_string(),
296 },
297 );
298 map.insert(
299 'M',
300 AminoAcidInfo {
301 name: "Methionine".to_string(),
302 molecular_weight: 149.2,
303 hydrophobicity: 1.9,
304 charge: 0,
305 single_letter: 'M',
306 three_letter: "Met".to_string(),
307 },
308 );
309 map.insert(
310 'F',
311 AminoAcidInfo {
312 name: "Phenylalanine".to_string(),
313 molecular_weight: 165.2,
314 hydrophobicity: 2.8,
315 charge: 0,
316 single_letter: 'F',
317 three_letter: "Phe".to_string(),
318 },
319 );
320 map.insert(
321 'P',
322 AminoAcidInfo {
323 name: "Proline".to_string(),
324 molecular_weight: 115.1,
325 hydrophobicity: -1.6,
326 charge: 0,
327 single_letter: 'P',
328 three_letter: "Pro".to_string(),
329 },
330 );
331 map.insert(
332 'S',
333 AminoAcidInfo {
334 name: "Serine".to_string(),
335 molecular_weight: 105.1,
336 hydrophobicity: -0.8,
337 charge: 0,
338 single_letter: 'S',
339 three_letter: "Ser".to_string(),
340 },
341 );
342 map.insert(
343 'T',
344 AminoAcidInfo {
345 name: "Threonine".to_string(),
346 molecular_weight: 119.1,
347 hydrophobicity: -0.7,
348 charge: 0,
349 single_letter: 'T',
350 three_letter: "Thr".to_string(),
351 },
352 );
353 map.insert(
354 'W',
355 AminoAcidInfo {
356 name: "Tryptophan".to_string(),
357 molecular_weight: 204.2,
358 hydrophobicity: -0.9,
359 charge: 0,
360 single_letter: 'W',
361 three_letter: "Trp".to_string(),
362 },
363 );
364 map.insert(
365 'Y',
366 AminoAcidInfo {
367 name: "Tyrosine".to_string(),
368 molecular_weight: 181.2,
369 hydrophobicity: -1.3,
370 charge: 0,
371 single_letter: 'Y',
372 three_letter: "Tyr".to_string(),
373 },
374 );
375 map.insert(
376 'V',
377 AminoAcidInfo {
378 name: "Valine".to_string(),
379 molecular_weight: 117.1,
380 hydrophobicity: 4.2,
381 charge: 0,
382 single_letter: 'V',
383 three_letter: "Val".to_string(),
384 },
385 );
386
387 map.insert(
389 'X',
390 AminoAcidInfo {
391 name: "Unknown".to_string(),
392 molecular_weight: 0.0,
393 hydrophobicity: 0.0,
394 charge: 0,
395 single_letter: 'X',
396 three_letter: "Xaa".to_string(),
397 },
398 );
399
400 map
401});
402
403static NUCLEOTIDES: Lazy<HashMap<char, NucleotideInfo>> = Lazy::new(|| {
405 let mut map = HashMap::new();
406
407 map.insert(
408 'A',
409 NucleotideInfo {
410 name: "Adenine".to_string(),
411 complement: 'T',
412 is_purine: true,
413 molecular_weight: 331.2,
414 },
415 );
416 map.insert(
417 'T',
418 NucleotideInfo {
419 name: "Thymine".to_string(),
420 complement: 'A',
421 is_purine: false,
422 molecular_weight: 322.2,
423 },
424 );
425 map.insert(
426 'G',
427 NucleotideInfo {
428 name: "Guanine".to_string(),
429 complement: 'C',
430 is_purine: true,
431 molecular_weight: 347.2,
432 },
433 );
434 map.insert(
435 'C',
436 NucleotideInfo {
437 name: "Cytosine".to_string(),
438 complement: 'G',
439 is_purine: false,
440 molecular_weight: 307.2,
441 },
442 );
443 map.insert(
444 'U',
445 NucleotideInfo {
446 name: "Uracil".to_string(),
447 complement: 'A',
448 is_purine: false,
449 molecular_weight: 308.2,
450 },
451 );
452
453 map.insert(
455 'N',
456 NucleotideInfo {
457 name: "Any nucleotide".to_string(),
458 complement: 'N',
459 is_purine: false,
460 molecular_weight: 0.0,
461 },
462 );
463 map.insert(
464 'R',
465 NucleotideInfo {
466 name: "Purine".to_string(),
467 complement: 'Y',
468 is_purine: true,
469 molecular_weight: 0.0,
470 },
471 );
472 map.insert(
473 'Y',
474 NucleotideInfo {
475 name: "Pyrimidine".to_string(),
476 complement: 'R',
477 is_purine: false,
478 molecular_weight: 0.0,
479 },
480 );
481
482 map
483});
484
/// Standard genetic code (NCBI translation table 1): maps each DNA codon to
/// its amino acid's single-letter code, with `*` marking the three stop codons.
///
/// Refactor: the 64 hand-written inserts are generated from the canonical
/// table string (amino acids listed in T/C/A/G nesting order, so the codon at
/// index `16*i + 4*j + k` is `BASES[i] BASES[j] BASES[k]`), and the
/// third-party `once_cell::sync::Lazy` is replaced with stdlib `LazyLock`.
static GENETIC_CODE: LazyLock<HashMap<String, char>> = LazyLock::new(|| {
    const BASES: [char; 4] = ['T', 'C', 'A', 'G'];
    const AMINO: &str = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG";

    let mut map = HashMap::with_capacity(64);
    for (index, amino_acid) in AMINO.chars().enumerate() {
        let codon: String =
            [BASES[index / 16], BASES[(index / 4) % 4], BASES[index % 4]].iter().collect();
        map.insert(codon, amino_acid);
    }
    map
});
560
561impl Default for BioTokenizer {
562 fn default() -> Self {
563 Self::new()
564 }
565}
566
567impl BioTokenizer {
568 pub fn new() -> Self {
570 Self::with_config(BioTokenizerConfig::default())
571 }
572
573 pub fn with_config(config: BioTokenizerConfig) -> Self {
575 let mut tokenizer = Self {
576 config,
577 vocab: HashMap::new(),
578 id_to_token: HashMap::new(),
579 next_id: 0,
580 amino_acids: AMINO_ACIDS.clone(),
581 nucleotides: NUCLEOTIDES.clone(),
582 genetic_code: GENETIC_CODE.clone(),
583 structure_patterns: Self::create_structure_patterns(),
584 };
585
586 tokenizer.initialize_vocab();
587 tokenizer
588 }
589
590 fn initialize_vocab(&mut self) {
592 if self.config.include_special_tokens {
594 self.add_token("[CLS]");
595 self.add_token("[SEP]");
596 self.add_token("[PAD]");
597 self.add_token("[UNK]");
598 self.add_token("[MASK]");
599 self.add_token("[START_SEQ]");
600 self.add_token("[END_SEQ]");
601 self.add_token("[START_PROTEIN]");
602 self.add_token("[END_PROTEIN]");
603 self.add_token("[START_DNA]");
604 self.add_token("[END_DNA]");
605 self.add_token("[START_RNA]");
606 self.add_token("[END_RNA]");
607 }
608
609 if self.config.tokenize_dna || self.config.tokenize_rna {
611 let nucleotides: Vec<String> = self.nucleotides.keys().map(|c| c.to_string()).collect();
612 for nucleotide in nucleotides {
613 self.add_token(&nucleotide);
614 }
615 }
616
617 if self.config.tokenize_proteins {
619 let amino_acids: Vec<String> = self.amino_acids.keys().map(|c| c.to_string()).collect();
620 for amino_acid in amino_acids {
621 self.add_token(&amino_acid);
622 }
623 }
624
625 if let Some(k) = self.config.kmer_size {
627 self.generate_kmers(k);
628 }
629
630 self.add_token("ATG"); self.add_token("TAA"); self.add_token("TAG"); self.add_token("TGA"); if self.config.tokenize_structure {
638 self.add_token("H"); self.add_token("E"); self.add_token("C"); self.add_token("T"); }
643 }
644
645 fn generate_kmers(&mut self, k: usize) {
647 if self.config.tokenize_dna || self.config.tokenize_rna {
648 let nucleotides = if self.config.tokenize_rna { "AUGC" } else { "ATGC" };
649 self.generate_kmer_combinations(nucleotides.chars().collect(), k, String::new());
650 }
651
652 if self.config.tokenize_proteins {
653 let amino_acids: Vec<char> = self.amino_acids.keys().copied().collect();
654 if k <= 3 {
655 self.generate_kmer_combinations(amino_acids, k, String::new());
657 }
658 }
659 }
660
661 fn generate_kmer_combinations(&mut self, alphabet: Vec<char>, k: usize, current: String) {
663 if current.len() == k {
664 self.add_token(¤t);
665 return;
666 }
667
668 if self.vocab.len() >= self.config.vocab_size.unwrap_or(5000) {
669 return; }
671
672 for &c in &alphabet {
673 let mut next = current.clone();
674 next.push(c);
675 self.generate_kmer_combinations(alphabet.clone(), k, next);
676 }
677 }
678
679 fn add_token(&mut self, token: &str) -> u32 {
681 if let Some(&id) = self.vocab.get(token) {
682 return id;
683 }
684
685 let id = self.next_id;
686 self.vocab.insert(token.to_string(), id);
687 self.id_to_token.insert(id, token.to_string());
688 self.next_id += 1;
689 id
690 }
691
692 fn create_structure_patterns() -> Vec<Regex> {
694 vec![
695 Regex::new(r"[HEC]+").expect("valid regex"), Regex::new(r"[αβ]+").expect("valid regex"), ]
698 }
699
700 pub fn tokenize_bio(&self, sequence: &str) -> Result<Vec<BioToken>> {
702 let sequence = if self.config.preserve_case {
703 sequence.to_string()
704 } else {
705 sequence.to_uppercase()
706 };
707
708 let seq_type = self.detect_sequence_type(&sequence);
710
711 match seq_type {
712 SequenceType::DNA => self.tokenize_dna(&sequence),
713 SequenceType::RNA => self.tokenize_rna(&sequence),
714 SequenceType::Protein => self.tokenize_protein(&sequence),
715 SequenceType::Structure => self.tokenize_structure(&sequence),
716 SequenceType::Unknown => self.tokenize_fallback(&sequence),
717 }
718 }
719
720 fn detect_sequence_type(&self, sequence: &str) -> SequenceType {
722 let chars: Vec<char> = sequence.chars().collect();
723 let total = chars.len() as f64;
724
725 if total == 0.0 {
726 return SequenceType::Unknown;
727 }
728
729 let dna_chars = chars.iter().filter(|&&c| "ATGC".contains(c)).count() as f64 / total;
731 let rna_chars = chars.iter().filter(|&&c| "AUGC".contains(c)).count() as f64 / total;
732 let protein_chars =
733 chars.iter().filter(|&&c| self.amino_acids.contains_key(&c)).count() as f64 / total;
734 let structure_chars = chars.iter().filter(|&&c| "HEC".contains(c)).count() as f64 / total;
735
736 if dna_chars > 0.8 && !sequence.contains('U') {
738 SequenceType::DNA
739 } else if rna_chars > 0.8 && sequence.contains('U') {
740 SequenceType::RNA
741 } else if protein_chars > 0.8 {
742 SequenceType::Protein
743 } else if structure_chars > 0.5 {
744 SequenceType::Structure
745 } else {
746 SequenceType::Unknown
747 }
748 }
749
750 fn tokenize_dna(&self, sequence: &str) -> Result<Vec<BioToken>> {
752 let mut tokens = Vec::new();
753
754 if let Some(k) = self.config.kmer_size {
755 tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::DNANucleotide)?);
757 } else {
758 for (i, c) in sequence.char_indices() {
760 let token_type = if self.nucleotides.contains_key(&c) {
761 if "ATGC".contains(c) {
762 BioTokenType::DNANucleotide
763 } else {
764 BioTokenType::AmbiguousNucleotide
765 }
766 } else {
767 BioTokenType::Unknown
768 };
769
770 let metadata = self.create_nucleotide_metadata(c);
771
772 tokens.push(BioToken {
773 text: c.to_string(),
774 token_type,
775 start: i,
776 end: i + 1,
777 metadata,
778 });
779 }
780 }
781
782 Ok(tokens)
783 }
784
785 fn tokenize_rna(&self, sequence: &str) -> Result<Vec<BioToken>> {
787 let mut tokens = Vec::new();
788
789 if let Some(k) = self.config.kmer_size {
790 tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::RNANucleotide)?);
791 } else {
792 for (i, c) in sequence.char_indices() {
793 let token_type = if self.nucleotides.contains_key(&c) {
794 if "AUGC".contains(c) {
795 BioTokenType::RNANucleotide
796 } else {
797 BioTokenType::AmbiguousNucleotide
798 }
799 } else {
800 BioTokenType::Unknown
801 };
802
803 let metadata = self.create_nucleotide_metadata(c);
804
805 tokens.push(BioToken {
806 text: c.to_string(),
807 token_type,
808 start: i,
809 end: i + 1,
810 metadata,
811 });
812 }
813 }
814
815 Ok(tokens)
816 }
817
818 fn tokenize_protein(&self, sequence: &str) -> Result<Vec<BioToken>> {
820 let mut tokens = Vec::new();
821
822 if let Some(k) = self.config.kmer_size {
823 tokens.extend(self.tokenize_kmers(sequence, k, BioTokenType::AminoAcid)?);
824 } else {
825 for (i, c) in sequence.char_indices() {
826 let token_type = if self.amino_acids.contains_key(&c) {
827 if "ACDEFGHIKLMNPQRSTVWY".contains(c) {
828 BioTokenType::AminoAcid
829 } else {
830 BioTokenType::AmbiguousAminoAcid
831 }
832 } else {
833 BioTokenType::Unknown
834 };
835
836 let metadata = self.create_amino_acid_metadata(c);
837
838 tokens.push(BioToken {
839 text: c.to_string(),
840 token_type,
841 start: i,
842 end: i + 1,
843 metadata,
844 });
845 }
846 }
847
848 Ok(tokens)
849 }
850
851 fn tokenize_structure(&self, sequence: &str) -> Result<Vec<BioToken>> {
853 let mut tokens = Vec::new();
854
855 for (i, c) in sequence.char_indices() {
856 let token_type = BioTokenType::SecondaryStructure;
857 let structure_type = match c {
858 'H' => Some("Helix".to_string()),
859 'E' => Some("Beta sheet".to_string()),
860 'C' => Some("Coil".to_string()),
861 'T' => Some("Turn".to_string()),
862 _ => Some("Unknown".to_string()),
863 };
864
865 let metadata = BioTokenMetadata {
866 structure_type,
867 ..Default::default()
868 };
869
870 tokens.push(BioToken {
871 text: c.to_string(),
872 token_type,
873 start: i,
874 end: i + 1,
875 metadata: Some(metadata),
876 });
877 }
878
879 Ok(tokens)
880 }
881
882 fn tokenize_kmers(
884 &self,
885 sequence: &str,
886 k: usize,
887 _base_type: BioTokenType,
888 ) -> Result<Vec<BioToken>> {
889 let mut tokens = Vec::new();
890 let chars: Vec<char> = sequence.chars().collect();
891
892 if chars.len() < k {
893 return Ok(tokens);
894 }
895
896 let step = if self.config.overlapping_kmers { 1 } else { k };
897
898 for i in (0..=chars.len() - k).step_by(step) {
899 let kmer: String = chars[i..i + k].iter().collect();
900
901 let token_type = if kmer.len() == 3 && self.genetic_code.contains_key(&kmer) {
902 if self.genetic_code[&kmer] == '*' {
903 BioTokenType::StopCodon
904 } else if kmer == "ATG" {
905 BioTokenType::StartCodon
906 } else {
907 BioTokenType::Kmer
908 }
909 } else {
910 BioTokenType::Kmer
911 };
912
913 let metadata = self.create_kmer_metadata(&kmer, i);
914
915 tokens.push(BioToken {
916 text: kmer,
917 token_type,
918 start: i,
919 end: i + k,
920 metadata,
921 });
922 }
923
924 Ok(tokens)
925 }
926
927 fn tokenize_fallback(&self, sequence: &str) -> Result<Vec<BioToken>> {
929 let mut tokens = Vec::new();
930
931 for (i, c) in sequence.char_indices() {
932 tokens.push(BioToken {
933 text: c.to_string(),
934 token_type: BioTokenType::Unknown,
935 start: i,
936 end: i + 1,
937 metadata: None,
938 });
939 }
940
941 Ok(tokens)
942 }
943
944 fn create_nucleotide_metadata(&self, nucleotide: char) -> Option<BioTokenMetadata> {
946 self.nucleotides.get(&nucleotide).map(|info| BioTokenMetadata {
947 molecular_weight: Some(info.molecular_weight),
948 hydrophobicity: None,
949 charge: None,
950 gc_content: if "GC".contains(nucleotide) { Some(1.0) } else { Some(0.0) },
951 melting_temp: None,
952 codon_position: None,
953 reading_frame: None,
954 structure_type: None,
955 })
956 }
957
958 fn create_amino_acid_metadata(&self, amino_acid: char) -> Option<BioTokenMetadata> {
960 self.amino_acids.get(&amino_acid).map(|info| BioTokenMetadata {
961 molecular_weight: Some(info.molecular_weight),
962 hydrophobicity: Some(info.hydrophobicity),
963 charge: Some(info.charge),
964 gc_content: None,
965 melting_temp: None,
966 codon_position: None,
967 reading_frame: None,
968 structure_type: None,
969 })
970 }
971
972 fn create_kmer_metadata(&self, kmer: &str, position: usize) -> Option<BioTokenMetadata> {
974 let mut metadata = BioTokenMetadata::default();
975
976 if kmer.chars().all(|c| "ATGCU".contains(c)) {
978 let gc_count = kmer.chars().filter(|&c| "GC".contains(c)).count();
979 metadata.gc_content = Some(gc_count as f64 / kmer.len() as f64);
980 }
981
982 if kmer.len() == 3 {
984 metadata.reading_frame = Some((position % 3) as u8);
985 if let Some(&amino_acid) = self.genetic_code.get(kmer) {
986 if let Some(aa_info) = self.amino_acids.get(&amino_acid) {
987 metadata.molecular_weight = Some(aa_info.molecular_weight);
988 metadata.hydrophobicity = Some(aa_info.hydrophobicity);
989 metadata.charge = Some(aa_info.charge);
990 }
991 }
992 }
993
994 Some(metadata)
995 }
996
    /// Returns the token-to-id vocabulary map.
    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.vocab
    }

    /// Looks up the token text for `id`, if the id is in the vocabulary.
    pub fn id_to_token(&self, id: u32) -> Option<&String> {
        self.id_to_token.get(&id)
    }

    /// Returns the tokenizer configuration.
    pub fn config(&self) -> &BioTokenizerConfig {
        &self.config
    }
1011
1012 pub fn translate_dna(&self, dna_sequence: &str) -> Result<String> {
1014 let dna = dna_sequence.to_uppercase();
1015 let mut protein = String::new();
1016
1017 for i in (0..dna.len()).step_by(3) {
1018 if i + 3 <= dna.len() {
1019 let codon = &dna[i..i + 3];
1020 if let Some(&amino_acid) = self.genetic_code.get(codon) {
1021 if amino_acid == '*' {
1022 break; }
1024 protein.push(amino_acid);
1025 } else {
1026 protein.push('X'); }
1028 }
1029 }
1030
1031 Ok(protein)
1032 }
1033
1034 pub fn reverse_complement(&self, dna_sequence: &str) -> String {
1036 dna_sequence
1037 .chars()
1038 .rev()
1039 .map(|c| {
1040 if let Some(info) = self.nucleotides.get(&c.to_ascii_uppercase()) {
1041 info.complement
1042 } else {
1043 'N'
1044 }
1045 })
1046 .collect()
1047 }
1048}
1049
/// Sequence classes recognized by `detect_sequence_type`.
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::upper_case_acronyms)]
enum SequenceType {
    /// Mostly A/T/G/C with no uracil.
    DNA,
    /// Mostly A/U/G/C and contains uracil.
    RNA,
    /// Mostly amino-acid single-letter codes.
    Protein,
    /// Mostly H/E/C secondary-structure labels.
    Structure,
    /// None of the above; falls back to per-character `Unknown` tokens.
    Unknown,
}
1060
1061impl Tokenizer for BioTokenizer {
1062 fn encode(&self, text: &str) -> Result<TokenizedInput> {
1063 let bio_tokens = self.tokenize_bio(text)?;
1064 let mut input_ids = Vec::new();
1065
1066 for token in bio_tokens {
1067 if let Some(&id) = self.vocab.get(&token.text) {
1068 input_ids.push(id);
1069 } else {
1070 if let Some(&unk_id) = self.vocab.get("[UNK]") {
1072 input_ids.push(unk_id);
1073 } else {
1074 input_ids.push(0); }
1076 }
1077 }
1078
1079 if let Some(max_len) = self.config.max_length {
1081 input_ids.truncate(max_len);
1082 }
1083
1084 let input_len = input_ids.len();
1085 Ok(TokenizedInput {
1086 input_ids,
1087 attention_mask: vec![1; input_len],
1088 token_type_ids: None,
1089 special_tokens_mask: None,
1090 offset_mapping: None,
1091 overflowing_tokens: None,
1092 })
1093 }
1094
1095 fn decode(&self, token_ids: &[u32]) -> Result<String> {
1096 let mut result = String::new();
1097
1098 for &id in token_ids {
1099 if let Some(token) = self.id_to_token.get(&id) {
1100 if !token.starts_with('[') || !token.ends_with(']') {
1101 result.push_str(token);
1102 }
1103 }
1104 }
1105
1106 Ok(result)
1107 }
1108
1109 fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
1110 let mut tokenized_a = self.encode(text_a)?;
1111 let tokenized_b = self.encode(text_b)?;
1112
1113 if let Some(&sep_id) = self.vocab.get("[SEP]") {
1115 tokenized_a.input_ids.push(sep_id);
1116 }
1117
1118 tokenized_a.input_ids.extend(tokenized_b.input_ids);
1119
1120 if let Some(max_len) = self.config.max_length {
1122 tokenized_a.input_ids.truncate(max_len);
1123 }
1124
1125 Ok(tokenized_a)
1126 }
1127
    /// Number of entries in the vocabulary.
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Returns a copy of the token-to-id map.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    /// Looks up the id for a token string.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    /// Looks up the token string for an id.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.id_to_token.get(&id).cloned()
    }
1143}
1144
/// Aggregate statistics produced by [`BioTokenizer::analyze`].
pub struct BioAnalysis {
    /// Count of tokens per token type.
    pub token_types: HashMap<BioTokenType, usize>,
    /// Per-amino-acid residue counts.
    pub amino_acid_composition: HashMap<char, usize>,
    /// Per-nucleotide counts.
    pub nucleotide_composition: HashMap<char, usize>,
    /// GC fraction over counted nucleotides; `None` if none were seen.
    pub gc_content: Option<f64>,
    /// Summed residue molecular weight; `None` if no amino acids were seen.
    pub molecular_weight: Option<f64>,
    /// Mean hydrophobicity per residue; `None` if no amino acids were seen.
    pub avg_hydrophobicity: Option<f64>,
    /// Summed residue charge; `None` if no amino acids were seen.
    pub net_charge: Option<i32>,
    /// Gini-Simpson diversity (1 - sum of squared frequencies) over token
    /// types — despite the name, not over distinct k-mers.
    pub kmer_diversity: f64,
    /// Length of the input sequence in bytes.
    pub sequence_length: usize,
}
1166
1167impl BioTokenizer {
    /// Computes composition statistics for a sequence by tokenizing it and
    /// aggregating per-token information.
    ///
    /// Notes:
    /// - Protein aggregates (`molecular_weight`, `avg_hydrophobicity`,
    ///   `net_charge`) are `Some` only if at least one amino-acid residue
    ///   was counted.
    /// - With overlapping k-mers (the default config), each residue is counted
    ///   once per k-mer it appears in, so compositions reflect k-mer coverage
    ///   rather than raw sequence counts.
    /// - Multi-character tokens typed `StartCodon`/`StopCodon` are skipped by
    ///   the composition pass (only `Kmer` tokens are expanded per character).
    pub fn analyze(&self, sequence: &str) -> Result<BioAnalysis> {
        let tokens = self.tokenize_bio(sequence)?;

        // Accumulators for per-type counts and physico-chemical totals.
        let mut token_types = HashMap::new();
        let mut amino_acid_composition = HashMap::new();
        let mut nucleotide_composition = HashMap::new();
        let mut molecular_weight = 0.0;
        let mut total_hydrophobicity = 0.0;
        let mut net_charge = 0i32;
        let mut gc_count = 0;
        let mut nucleotide_count = 0;
        let mut protein_residue_count = 0;

        for token in &tokens {
            *token_types.entry(token.token_type.clone()).or_insert(0) += 1;

            if token.text.len() == 1 {
                // Single-character token: classify by its token type.
                let c = token
                    .text
                    .chars()
                    .next()
                    .expect("token.text with len()==1 must have at least one char");

                match token.token_type {
                    BioTokenType::AminoAcid => {
                        *amino_acid_composition.entry(c).or_insert(0) += 1;
                        if let Some(info) = self.amino_acids.get(&c) {
                            molecular_weight += info.molecular_weight;
                            total_hydrophobicity += info.hydrophobicity;
                            net_charge += info.charge as i32;
                            protein_residue_count += 1;
                        }
                    },
                    BioTokenType::DNANucleotide | BioTokenType::RNANucleotide => {
                        *nucleotide_composition.entry(c).or_insert(0) += 1;
                        if "GC".contains(c) {
                            gc_count += 1;
                        }
                        nucleotide_count += 1;
                    },
                    _ => {},
                }
            } else {
                // Multi-character token: only plain k-mers contribute, counted
                // per character. Nucleotide membership is checked first, so
                // letters shared by both alphabets (A, T, G, C) count as
                // nucleotides here.
                if token.token_type == BioTokenType::Kmer {
                    for c in token.text.chars() {
                        if "ATGCU".contains(c) {
                            *nucleotide_composition.entry(c).or_insert(0) += 1;
                            if "GC".contains(c) {
                                gc_count += 1;
                            }
                            nucleotide_count += 1;
                        }
                        else if self.amino_acids.contains_key(&c) {
                            *amino_acid_composition.entry(c).or_insert(0) += 1;
                            if let Some(info) = self.amino_acids.get(&c) {
                                molecular_weight += info.molecular_weight;
                                total_hydrophobicity += info.hydrophobicity;
                                net_charge += info.charge as i32;
                                protein_residue_count += 1;
                            }
                        }
                    }
                }
            }
        }

        // GC content is defined only if nucleotides were observed.
        let gc_content = if nucleotide_count > 0 {
            Some(gc_count as f64 / nucleotide_count as f64)
        } else {
            None
        };

        let avg_hydrophobicity = if protein_residue_count > 0 {
            Some(total_hydrophobicity / protein_residue_count as f64)
        } else {
            None
        };

        let molecular_weight_final =
            if protein_residue_count > 0 { Some(molecular_weight) } else { None };

        let net_charge_final = if protein_residue_count > 0 { Some(net_charge) } else { None };

        // Gini-Simpson diversity index over token-type frequencies.
        let total_tokens = tokens.len();
        let kmer_diversity = if total_tokens > 0 {
            let mut diversity = 0.0;
            for count in token_types.values() {
                let frequency = *count as f64 / total_tokens as f64;
                diversity += frequency * frequency;
            }
            1.0 - diversity
        } else {
            0.0
        };

        Ok(BioAnalysis {
            token_types,
            amino_acid_composition,
            nucleotide_composition,
            gc_content,
            molecular_weight: molecular_weight_final,
            avg_hydrophobicity,
            net_charge: net_charge_final,
            kmer_diversity,
            // NOTE(review): byte length — equals character count only for
            // ASCII sequences.
            sequence_length: sequence.len(),
        })
    }
1280}
1281
#[cfg(test)]
mod tests {
    use super::*;

    // Idiom fixes: struct-update syntax replaces mutate-after-default
    // (clippy::field_reassign_with_default), `!is_empty()` replaces
    // `len() > 0`, and expect messages describe the expectation.

    #[test]
    fn test_bio_tokenizer_creation() {
        let tokenizer = BioTokenizer::new();
        assert!(!tokenizer.get_vocab().is_empty());
        assert!(tokenizer.get_vocab().contains_key("A"));
        assert!(tokenizer.get_vocab().contains_key("T"));
    }

    #[test]
    fn test_sequence_type_detection() {
        let tokenizer = BioTokenizer::new();
        assert_eq!(tokenizer.detect_sequence_type("ATGCGATCG"), SequenceType::DNA);
        assert_eq!(tokenizer.detect_sequence_type("AUGCGAUCG"), SequenceType::RNA);
        assert_eq!(tokenizer.detect_sequence_type("MTKQVFTPG"), SequenceType::Protein);
    }

    #[test]
    fn test_dna_encoding() {
        let tokenizer = BioTokenizer::new();
        let tokenized = tokenizer.encode("ATGCGATCG").expect("encoding DNA should succeed");
        assert!(!tokenized.input_ids.is_empty());
    }

    #[test]
    fn test_protein_encoding() {
        let tokenizer = BioTokenizer::new();
        let tokenized = tokenizer.encode("MTKQVFTPG").expect("encoding protein should succeed");
        assert!(!tokenized.input_ids.is_empty());
    }

    #[test]
    fn test_kmer_tokenization() {
        let config = BioTokenizerConfig { kmer_size: Some(3), ..Default::default() };
        let tokenizer = BioTokenizer::with_config(config);

        let tokens = tokenizer.tokenize_bio("ATGCGATCG").expect("tokenization should succeed");
        assert!(tokens.iter().any(|t| t.text.len() == 3));
    }

    #[test]
    fn test_translation() {
        let tokenizer = BioTokenizer::new();
        let protein = tokenizer.translate_dna("ATGAAATAG").expect("translation should succeed");
        // ATG -> M, AAA -> K, TAG -> stop.
        assert_eq!(protein, "MK");
    }

    #[test]
    fn test_reverse_complement() {
        let tokenizer = BioTokenizer::new();
        assert_eq!(tokenizer.reverse_complement("ATGC"), "GCAT");
    }

    #[test]
    fn test_amino_acid_metadata() {
        let tokenizer = BioTokenizer::new();
        let meta =
            tokenizer.create_amino_acid_metadata('A').expect("'A' is a known amino acid");
        assert!(meta.molecular_weight.is_some());
        assert!(meta.hydrophobicity.is_some());
    }

    #[test]
    fn test_nucleotide_metadata() {
        let tokenizer = BioTokenizer::new();
        let meta =
            tokenizer.create_nucleotide_metadata('G').expect("'G' is a known nucleotide");
        assert_eq!(meta.gc_content, Some(1.0));
    }

    #[test]
    fn test_bio_analysis() {
        let tokenizer = BioTokenizer::new();
        let result = tokenizer.analyze("ATGCGATCG").expect("analysis should succeed");
        assert!(result.gc_content.is_some());
        assert!(!result.nucleotide_composition.is_empty());
    }

    #[test]
    fn test_protein_analysis() {
        let tokenizer = BioTokenizer::new();
        let result = tokenizer.analyze("MTKQVFTPG").expect("analysis should succeed");
        assert!(result.molecular_weight.is_some());
        assert!(result.avg_hydrophobicity.is_some());
        assert!(!result.amino_acid_composition.is_empty());
    }

    #[test]
    fn test_stop_codon_detection() {
        let tokenizer = BioTokenizer::new();
        let tokens = tokenizer.tokenize_bio("ATGTAG").expect("tokenization should succeed");
        assert!(tokens.iter().any(|t| t.token_type == BioTokenType::StartCodon));
        assert!(tokens.iter().any(|t| t.token_type == BioTokenType::StopCodon));
    }

    #[test]
    fn test_max_length_constraint() {
        let config = BioTokenizerConfig { max_length: Some(5), ..Default::default() };
        let tokenizer = BioTokenizer::with_config(config);

        let tokenized =
            tokenizer.encode("ATGCGATCGATCGATCG").expect("encoding should succeed");
        assert!(tokenized.input_ids.len() <= 5);
    }
}