1use crate::error::{Result, TextError};
31use std::collections::HashMap;
32use std::fs::File;
33use std::io::{BufRead, BufReader, BufWriter, Write as IoWrite};
34use std::path::Path;
35
/// Common interface implemented by every tokenizer in this module.
pub trait TransformerTokenizer {
    /// Convert `text` into a sequence of token ids.
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Convert token ids back into text (lossy for some tokenizers,
    /// e.g. BPE does not record word boundaries).
    fn decode(&self, ids: &[u32]) -> String;

    /// Number of entries in the vocabulary.
    fn vocab_size(&self) -> usize;
}
55
/// Normalize raw text: lowercase it and collapse all whitespace runs
/// into single spaces.
fn pre_tokenize(text: &str) -> String {
    let lowered = text.to_lowercase();
    let words: Vec<&str> = lowered.split_whitespace().collect();
    words.join(" ")
}
66
/// Split normalized text on whitespace, yielding owned word strings.
fn split_words(text: &str) -> Vec<String> {
    text.split_whitespace().map(String::from).collect()
}
71
#[derive(Debug, Clone)]
pub struct BPETokenizer {
    /// Token string -> id.
    vocab: HashMap<String, u32>,
    /// id -> token string (inverse of `vocab`).
    id_to_token: HashMap<u32, String>,
    /// Learned merge rules, in the order they were learned (replayed
    /// in this order during encoding).
    merges: Vec<(String, String)>,
    /// Subset of `vocab` holding the registered special tokens.
    special_tokens: HashMap<String, u32>,
}

/// Special tokens seeded at ids 0..5 by `BPETokenizer::new`.
const DEFAULT_SPECIAL_TOKENS: &[&str] = &["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"];
111
112impl BPETokenizer {
113 pub fn new() -> Self {
115 let mut vocab = HashMap::new();
116 let mut id_to_token = HashMap::new();
117 let mut special_tokens = HashMap::new();
118
119 for (i, &tok) in DEFAULT_SPECIAL_TOKENS.iter().enumerate() {
120 let id = i as u32;
121 vocab.insert(tok.to_string(), id);
122 id_to_token.insert(id, tok.to_string());
123 special_tokens.insert(tok.to_string(), id);
124 }
125
126 Self {
127 vocab,
128 id_to_token,
129 merges: Vec::new(),
130 special_tokens,
131 }
132 }
133
134 pub fn train(texts: &[&str], vocab_size: usize) -> Result<Self> {
149 if texts.is_empty() {
150 return Err(TextError::TokenizationError(
151 "Cannot train on empty corpus".to_string(),
152 ));
153 }
154 if vocab_size < DEFAULT_SPECIAL_TOKENS.len() + 1 {
155 return Err(TextError::TokenizationError(format!(
156 "vocab_size must be at least {} (special tokens + 1)",
157 DEFAULT_SPECIAL_TOKENS.len() + 1
158 )));
159 }
160
161 let mut tokenizer = Self::new();
162
163 let mut char_set: Vec<char> = Vec::new();
165 for text in texts {
166 let normalized = pre_tokenize(text);
167 for ch in normalized.chars() {
168 if !char_set.contains(&ch) {
169 char_set.push(ch);
170 }
171 }
172 }
173 char_set.sort();
174
175 for ch in &char_set {
176 let s = ch.to_string();
177 if !tokenizer.vocab.contains_key(&s) {
178 let id = tokenizer.vocab.len() as u32;
179 tokenizer.vocab.insert(s.clone(), id);
180 tokenizer.id_to_token.insert(id, s);
181 }
182 }
183
184 let mut word_freqs: HashMap<Vec<String>, u64> = HashMap::new();
187 for text in texts {
188 let normalized = pre_tokenize(text);
189 for word in split_words(&normalized) {
190 let symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
191 *word_freqs.entry(symbols).or_insert(0) += 1;
192 }
193 }
194
195 let max_merges = vocab_size.saturating_sub(tokenizer.vocab.len());
197
198 for _ in 0..max_merges {
199 let mut pair_counts: HashMap<(String, String), u64> = HashMap::new();
201 for (symbols, freq) in &word_freqs {
202 if symbols.len() < 2 {
203 continue;
204 }
205 for pair in symbols.windows(2) {
206 let key = (pair[0].clone(), pair[1].clone());
207 *pair_counts.entry(key).or_insert(0) += freq;
208 }
209 }
210
211 let best = pair_counts
213 .iter()
214 .max_by_key(|&(_, &count)| count)
215 .map(|(pair, _)| pair.clone());
216
217 let best = match best {
218 Some(p) => p,
219 None => break, };
221
222 let merged = format!("{}{}", best.0, best.1);
223
224 tokenizer.merges.push(best.clone());
226 if !tokenizer.vocab.contains_key(&merged) {
227 let id = tokenizer.vocab.len() as u32;
228 tokenizer.vocab.insert(merged.clone(), id);
229 tokenizer.id_to_token.insert(id, merged.clone());
230 }
231
232 let mut new_word_freqs: HashMap<Vec<String>, u64> = HashMap::new();
234 for (symbols, freq) in &word_freqs {
235 let updated = apply_merge(symbols, &best.0, &best.1, &merged);
236 *new_word_freqs.entry(updated).or_insert(0) += freq;
237 }
238 word_freqs = new_word_freqs;
239 }
240
241 Ok(tokenizer)
242 }
243
244 fn unk_id(&self) -> u32 {
246 self.special_tokens.get("[UNK]").copied().unwrap_or(1)
247 }
248
249 fn encode_word(&self, word: &str) -> Vec<u32> {
251 if word.is_empty() {
252 return Vec::new();
253 }
254
255 let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
256
257 for (left, right) in &self.merges {
259 let merged = format!("{}{}", left, right);
260 symbols = apply_merge(&symbols, left, right, &merged);
261 }
262
263 let unk = self.unk_id();
265 symbols
266 .iter()
267 .map(|s| self.vocab.get(s).copied().unwrap_or(unk))
268 .collect()
269 }
270
271 pub fn special_tokens(&self) -> &HashMap<String, u32> {
273 &self.special_tokens
274 }
275
276 pub fn special_token_id(&self, name: &str) -> Option<u32> {
278 self.special_tokens.get(name).copied()
279 }
280
281 pub fn add_special_token(&mut self, token: &str) -> u32 {
283 if let Some(&id) = self.vocab.get(token) {
284 self.special_tokens.insert(token.to_string(), id);
285 return id;
286 }
287 let id = self.vocab.len() as u32;
288 self.vocab.insert(token.to_string(), id);
289 self.id_to_token.insert(id, token.to_string());
290 self.special_tokens.insert(token.to_string(), id);
291 id
292 }
293
294 pub fn save_json(&self, path: &Path) -> Result<()> {
298 let file = File::create(path).map_err(|e| TextError::IoError(format!("save_json: {e}")))?;
299 let writer = BufWriter::new(file);
300
301 write_bpe_json(writer, &self.vocab, &self.merges, &self.special_tokens)
304 }
305
306 pub fn load_json(path: &Path) -> Result<Self> {
308 let file = File::open(path).map_err(|e| TextError::IoError(format!("load_json: {e}")))?;
309 let reader = BufReader::new(file);
310 read_bpe_json(reader)
311 }
312}
313
314impl Default for BPETokenizer {
315 fn default() -> Self {
316 Self::new()
317 }
318}
319
320impl TransformerTokenizer for BPETokenizer {
321 fn encode(&self, text: &str) -> Vec<u32> {
322 let normalized = pre_tokenize(text);
323 let words = split_words(&normalized);
324 let mut ids = Vec::new();
325 for word in &words {
326 ids.extend(self.encode_word(word));
327 }
328 ids
329 }
330
331 fn decode(&self, ids: &[u32]) -> String {
332 let tokens: Vec<String> = ids
333 .iter()
334 .filter_map(|&id| self.id_to_token.get(&id).cloned())
335 .collect();
336 rejoin_bpe_tokens(&tokens)
344 }
345
346 fn vocab_size(&self) -> usize {
347 self.vocab.len()
348 }
349}
350
/// Replace every non-overlapping `left`,`right` adjacency in `symbols`
/// with the single symbol `merged`, scanning left to right.
fn apply_merge(symbols: &[String], left: &str, right: &str, merged: &str) -> Vec<String> {
    let mut out = Vec::with_capacity(symbols.len());
    let mut idx = 0;
    while let Some(sym) = symbols.get(idx) {
        let pair_here = sym == left && symbols.get(idx + 1).map_or(false, |next| next == right);
        if pair_here {
            out.push(merged.to_string());
            idx += 2;
        } else {
            out.push(sym.clone());
            idx += 1;
        }
    }
    out
}
370
/// Concatenate decoded BPE tokens back into one string. This BPE
/// implementation records no word boundaries, so no spaces can be
/// reinserted between words.
fn rejoin_bpe_tokens(tokens: &[String]) -> String {
    match tokens {
        [] => String::new(),
        _ => tokens.concat(),
    }
}
391
392fn write_bpe_json<W: IoWrite>(
394 mut w: W,
395 vocab: &HashMap<String, u32>,
396 merges: &[(String, String)],
397 special_tokens: &HashMap<String, u32>,
398) -> Result<()> {
399 let write_err = |e: std::io::Error| TextError::IoError(format!("write_bpe_json: {e}"));
400
401 w.write_all(b"{\n").map_err(write_err)?;
402
403 w.write_all(b" \"vocab\": {\n").map_err(write_err)?;
405 let mut sorted_vocab: Vec<(&String, &u32)> = vocab.iter().collect();
406 sorted_vocab.sort_by_key(|&(_, id)| *id);
407 for (idx, (token, id)) in sorted_vocab.iter().enumerate() {
408 let comma = if idx + 1 < sorted_vocab.len() {
409 ","
410 } else {
411 ""
412 };
413 let escaped = escape_json_string(token);
414 writeln!(w, " \"{}\": {}{}", escaped, id, comma).map_err(write_err)?;
415 }
416 w.write_all(b" },\n").map_err(write_err)?;
417
418 w.write_all(b" \"merges\": [\n").map_err(write_err)?;
420 for (idx, (left, right)) in merges.iter().enumerate() {
421 let comma = if idx + 1 < merges.len() { "," } else { "" };
422 let left_esc = escape_json_string(left);
423 let right_esc = escape_json_string(right);
424 writeln!(w, " [\"{}\", \"{}\"]{}", left_esc, right_esc, comma).map_err(write_err)?;
425 }
426 w.write_all(b" ],\n").map_err(write_err)?;
427
428 w.write_all(b" \"special_tokens\": {\n")
430 .map_err(write_err)?;
431 let mut sorted_special: Vec<(&String, &u32)> = special_tokens.iter().collect();
432 sorted_special.sort_by_key(|&(_, id)| *id);
433 for (idx, (token, id)) in sorted_special.iter().enumerate() {
434 let comma = if idx + 1 < sorted_special.len() {
435 ","
436 } else {
437 ""
438 };
439 let escaped = escape_json_string(token);
440 writeln!(w, " \"{}\": {}{}", escaped, id, comma).map_err(write_err)?;
441 }
442 w.write_all(b" }\n").map_err(write_err)?;
443
444 w.write_all(b"}\n").map_err(write_err)?;
445 Ok(())
446}
447
448fn read_bpe_json<R: BufRead>(reader: R) -> Result<BPETokenizer> {
453 let content: String = reader
454 .lines()
455 .collect::<std::result::Result<Vec<_>, _>>()
456 .map_err(|e| TextError::IoError(format!("read_bpe_json: {e}")))?
457 .join("\n");
458
459 let mut vocab: HashMap<String, u32> = HashMap::new();
460 let mut id_to_token: HashMap<u32, String> = HashMap::new();
461 let mut merges: Vec<(String, String)> = Vec::new();
462 let mut special_tokens: HashMap<String, u32> = HashMap::new();
463
464 if let Some(vocab_section) = extract_json_object(&content, "vocab") {
466 for (key, val) in parse_string_int_pairs(&vocab_section) {
467 vocab.insert(key.clone(), val);
468 id_to_token.insert(val, key);
469 }
470 }
471
472 if let Some(merges_section) = extract_json_array(&content, "merges") {
474 merges = parse_merge_pairs(&merges_section);
475 }
476
477 if let Some(special_section) = extract_json_object(&content, "special_tokens") {
479 for (key, val) in parse_string_int_pairs(&special_section) {
480 special_tokens.insert(key, val);
481 }
482 }
483
484 Ok(BPETokenizer {
485 vocab,
486 id_to_token,
487 merges,
488 special_tokens,
489 })
490}
491
/// Escape a string for embedding inside a JSON string literal: quotes,
/// backslashes, the common whitespace escapes, and `\uXXXX` for any
/// other control character. All other characters pass through verbatim.
fn escape_json_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len() + 2);
    for ch in s.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            _ if ch.is_control() => escaped.push_str(&format!("\\u{:04x}", ch as u32)),
            _ => escaped.push(ch),
        }
    }
    escaped
}
510
/// Invert `escape_json_string`. Unknown escapes are preserved verbatim
/// (backslash included); a truncated or invalid `\uXXXX` sequence is
/// silently dropped; a trailing lone backslash is kept.
///
/// NOTE(review): surrogate pairs are not recombined — the matching
/// escape routine never emits them, so this only matters for
/// hand-edited files.
fn unescape_json_string(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut iter = s.chars();
    while let Some(c) = iter.next() {
        if c != '\\' {
            result.push(c);
            continue;
        }
        match iter.next() {
            Some('"') => result.push('"'),
            Some('\\') => result.push('\\'),
            Some('n') => result.push('\n'),
            Some('r') => result.push('\r'),
            Some('t') => result.push('\t'),
            Some('u') => {
                let hex: String = iter.by_ref().take(4).collect();
                let decoded = u32::from_str_radix(&hex, 16).ok().and_then(char::from_u32);
                if let Some(ch) = decoded {
                    result.push(ch);
                }
            }
            Some(other) => {
                result.push('\\');
                result.push(other);
            }
            None => result.push('\\'),
        }
    }
    result
}
543
/// Extract the contents of the `{...}` object bound to `"key"`,
/// excluding the outer braces; `None` when the key or object is absent
/// or unterminated.
///
/// The brace matching is purely structural — a literal brace inside a
/// quoted token string would unbalance it. NOTE(review): acceptable for
/// files written by `write_bpe_json` only as long as tokens contain no
/// braces; not a general JSON parser.
fn extract_json_object(json: &str, key: &str) -> Option<String> {
    let pattern = format!("\"{}\"", key);
    let start = json.find(&pattern)?;
    let after_key = &json[start + pattern.len()..];
    let brace_start = after_key.find('{')?;
    let content_start = start + pattern.len() + brace_start;

    // Scan with byte offsets (char_indices), not char counts: the
    // previous enumerate()-based index was used to slice the string,
    // which panics or truncates whenever the object contains
    // multi-byte UTF-8 tokens.
    let mut depth = 0;
    for (i, ch) in json[content_start..].char_indices() {
        match ch {
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    return Some(json[content_start + 1..content_start + i].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
569
/// Extract the contents of the `[...]` array bound to `"key"`,
/// excluding the outer brackets; `None` when the key or array is absent
/// or unterminated.
///
/// Structural bracket matching only — a literal bracket inside a quoted
/// token would unbalance it (same caveat as `extract_json_object`).
fn extract_json_array(json: &str, key: &str) -> Option<String> {
    let pattern = format!("\"{}\"", key);
    let start = json.find(&pattern)?;
    let after_key = &json[start + pattern.len()..];
    let bracket_start = after_key.find('[')?;
    let content_start = start + pattern.len() + bracket_start;

    // char_indices gives byte offsets; the previous enumerate()-based
    // char index mis-sliced (or panicked) on multi-byte UTF-8 content.
    let mut depth = 0;
    for (i, ch) in json[content_start..].char_indices() {
        match ch {
            '[' => depth += 1,
            ']' => {
                depth -= 1;
                if depth == 0 {
                    return Some(json[content_start + 1..content_start + i].to_string());
                }
            }
            _ => {}
        }
    }
    None
}
593
/// Parse `"key": value` pairs (u32 values) from the interior of a JSON
/// object, tolerating commas and arbitrary whitespace between entries.
/// Entries whose value does not parse as u32 are silently skipped.
fn parse_string_int_pairs(content: &str) -> Vec<(String, u32)> {
    let mut pairs = Vec::new();
    let mut remaining = content.trim();

    while !remaining.is_empty() {
        // Opening quote of the key.
        let q1 = match remaining.find('"') {
            Some(pos) => pos,
            None => break,
        };
        let after_q1 = &remaining[q1 + 1..];
        // Closing quote, skipping backslash-escaped quotes inside the key.
        let q2 = match find_unescaped_quote(after_q1) {
            Some(pos) => pos,
            None => break,
        };
        let key = unescape_json_string(&after_q1[..q2]);
        let after_key = &after_q1[q2 + 1..];

        let colon = match after_key.find(':') {
            Some(pos) => pos,
            None => break,
        };
        let after_colon = after_key[colon + 1..].trim_start();

        // The value is the maximal run of ASCII digits after the colon.
        let num_end = after_colon
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(after_colon.len());
        if let Ok(val) = after_colon[..num_end].parse::<u32>() {
            pairs.push((key, val));
        }

        // Advance past the value and an optional trailing comma.
        let consumed = after_colon[num_end..].trim_start();
        remaining = if consumed.starts_with(',') {
            &consumed[1..]
        } else {
            consumed
        };
    }
    pairs
}
639
/// Parse `["left", "right"]` pairs from the interior of the "merges"
/// JSON array. Entries with fewer than two quoted strings are skipped.
///
/// NOTE(review): the closing `]` of each entry is located with a plain
/// `find`, so a literal `]` inside a merge token would truncate that
/// entry — confirm tokens can never contain `]` before relying on this
/// for arbitrary data.
fn parse_merge_pairs(content: &str) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    let mut remaining = content.trim();

    while !remaining.is_empty() {
        // Locate the next `[ ... ]` entry.
        let bracket = match remaining.find('[') {
            Some(pos) => pos,
            None => break,
        };
        let end_bracket = match remaining[bracket..].find(']') {
            Some(pos) => bracket + pos,
            None => break,
        };
        let inner = &remaining[bracket + 1..end_bracket];

        // Pull out up to two quoted strings from the entry.
        let mut strings = Vec::new();
        let mut inner_rem = inner.trim();
        for _ in 0..2 {
            let q1 = match inner_rem.find('"') {
                Some(pos) => pos,
                None => break,
            };
            let after_q1 = &inner_rem[q1 + 1..];
            let q2 = match find_unescaped_quote(after_q1) {
                Some(pos) => pos,
                None => break,
            };
            strings.push(unescape_json_string(&after_q1[..q2]));
            inner_rem = &after_q1[q2 + 1..];
            inner_rem = inner_rem.trim_start();
            if inner_rem.starts_with(',') {
                inner_rem = &inner_rem[1..];
            }
        }

        if strings.len() == 2 {
            pairs.push((strings[0].clone(), strings[1].clone()));
        }

        // Advance past this entry and an optional trailing comma.
        remaining = &remaining[end_bracket + 1..];
        remaining = remaining.trim_start();
        if remaining.starts_with(',') {
            remaining = &remaining[1..];
        }
    }
    pairs
}
690
/// Byte offset of the first `"` in `s` that is not preceded by a
/// backslash escape, or `None` if there is no such quote.
///
/// Returns a *byte* index so callers can slice the string with it
/// directly. The previous version returned a char count from
/// `enumerate()`, which callers used as a byte offset — that panicked
/// or mis-sliced whenever multi-byte UTF-8 preceded the quote.
fn find_unescaped_quote(s: &str) -> Option<usize> {
    let mut escaped = false;
    for (pos, ch) in s.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return Some(pos);
        }
    }
    None
}
709
#[derive(Debug, Clone)]
pub struct WordPieceTokenizer {
    /// Token string -> id (continuation pieces carry the prefix).
    vocab: HashMap<String, u32>,
    /// id -> token string (inverse of `vocab`).
    id_to_token: HashMap<u32, String>,
    /// Words longer than this map straight to the unknown token.
    max_word_len: usize,
    /// Token emitted for out-of-vocabulary pieces (default "[UNK]").
    unk_token: String,
    /// Prefix marking word-internal pieces (default "##").
    continuing_prefix: String,
}
750
751impl WordPieceTokenizer {
752 pub fn new(vocab: HashMap<String, u32>) -> Self {
757 let id_to_token: HashMap<u32, String> =
758 vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
759 Self {
760 vocab,
761 id_to_token,
762 max_word_len: 200,
763 unk_token: "[UNK]".to_string(),
764 continuing_prefix: "##".to_string(),
765 }
766 }
767
768 pub fn with_max_word_len(mut self, max_len: usize) -> Self {
770 self.max_word_len = max_len;
771 self
772 }
773
774 pub fn with_unk_token(mut self, unk: &str) -> Self {
776 self.unk_token = unk.to_string();
777 self
778 }
779
780 pub fn with_continuing_prefix(mut self, prefix: &str) -> Self {
782 self.continuing_prefix = prefix.to_string();
783 self
784 }
785
786 pub fn from_vocab_file(path: &Path) -> Result<Self> {
790 let file =
791 File::open(path).map_err(|e| TextError::IoError(format!("from_vocab_file: {e}")))?;
792 let reader = BufReader::new(file);
793
794 let mut vocab = HashMap::new();
795 for (id, line) in reader.lines().enumerate() {
796 let line =
797 line.map_err(|e| TextError::IoError(format!("from_vocab_file read: {e}")))?;
798 let token = line.trim().to_string();
799 if !token.is_empty() {
800 vocab.insert(token, id as u32);
801 }
802 }
803
804 if vocab.is_empty() {
805 return Err(TextError::TokenizationError(
806 "Vocabulary file is empty".to_string(),
807 ));
808 }
809
810 Ok(Self::new(vocab))
811 }
812
813 pub fn tokenize(&self, text: &str) -> Vec<String> {
818 let normalized = pre_tokenize(text);
819 let words = split_words(&normalized);
820 let mut tokens = Vec::new();
821
822 for word in &words {
823 if word.len() > self.max_word_len {
824 tokens.push(self.unk_token.clone());
825 continue;
826 }
827 let word_tokens = self.tokenize_word(word);
828 tokens.extend(word_tokens);
829 }
830 tokens
831 }
832
833 fn tokenize_word(&self, word: &str) -> Vec<String> {
835 let chars: Vec<char> = word.chars().collect();
836 let mut tokens = Vec::new();
837 let mut start = 0;
838
839 while start < chars.len() {
840 let mut end = chars.len();
841 let mut found = false;
842
843 while start < end {
844 let substr: String = chars[start..end].iter().collect();
845 let candidate = if start == 0 {
846 substr.clone()
847 } else {
848 format!("{}{}", self.continuing_prefix, substr)
849 };
850
851 if self.vocab.contains_key(&candidate) {
852 tokens.push(candidate);
853 found = true;
854 break;
855 }
856 end -= 1;
857 }
858
859 if !found {
860 tokens.push(self.unk_token.clone());
862 start += 1;
863 } else {
864 start = end;
865 }
866 }
867
868 tokens
869 }
870
871 fn unk_id(&self) -> u32 {
873 self.vocab.get(&self.unk_token).copied().unwrap_or(0)
874 }
875}
876
877impl TransformerTokenizer for WordPieceTokenizer {
878 fn encode(&self, text: &str) -> Vec<u32> {
879 let tokens = self.tokenize(text);
880 let unk = self.unk_id();
881 tokens
882 .iter()
883 .map(|t| self.vocab.get(t).copied().unwrap_or(unk))
884 .collect()
885 }
886
887 fn decode(&self, ids: &[u32]) -> String {
888 let mut result = String::new();
889 let mut first_in_word = true;
890
891 for &id in ids {
892 let token = match self.id_to_token.get(&id) {
893 Some(t) => t.as_str(),
894 None => &self.unk_token,
895 };
896
897 if let Some(stripped) = token.strip_prefix(&self.continuing_prefix) {
898 result.push_str(stripped);
900 } else {
901 if !first_in_word {
903 result.push(' ');
904 }
905 result.push_str(token);
906 }
907 first_in_word = false;
908 }
909 result
910 }
911
912 fn vocab_size(&self) -> usize {
913 self.vocab.len()
914 }
915}
916
#[derive(Debug, Clone)]
pub struct SimpleWhitespaceTokenizer {
    /// Word -> id; "[UNK]" occupies id 0 when built via `build`.
    vocab: HashMap<String, u32>,
    /// id -> word (inverse of `vocab`).
    id_to_token: HashMap<u32, String>,
    /// Id returned for out-of-vocabulary words.
    unk_id: u32,
}
946
947impl SimpleWhitespaceTokenizer {
948 pub fn build(texts: &[&str], max_vocab: usize) -> Self {
953 let mut word_counts: HashMap<String, u64> = HashMap::new();
954 for text in texts {
955 let normalized = pre_tokenize(text);
956 for word in split_words(&normalized) {
957 *word_counts.entry(word).or_insert(0) += 1;
958 }
959 }
960
961 let mut sorted: Vec<(String, u64)> = word_counts.into_iter().collect();
963 sorted.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
964
965 let mut vocab = HashMap::new();
966 let mut id_to_token = HashMap::new();
967
968 vocab.insert("[UNK]".to_string(), 0);
970 id_to_token.insert(0, "[UNK]".to_string());
971
972 let limit = max_vocab.saturating_sub(1); for (word, _) in sorted.into_iter().take(limit) {
974 let id = vocab.len() as u32;
975 id_to_token.insert(id, word.clone());
976 vocab.insert(word, id);
977 }
978
979 Self {
980 vocab,
981 id_to_token,
982 unk_id: 0,
983 }
984 }
985
986 pub fn from_vocab(vocab: HashMap<String, u32>, unk_id: u32) -> Self {
988 let id_to_token: HashMap<u32, String> =
989 vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
990 Self {
991 vocab,
992 id_to_token,
993 unk_id,
994 }
995 }
996}
997
998impl TransformerTokenizer for SimpleWhitespaceTokenizer {
999 fn encode(&self, text: &str) -> Vec<u32> {
1000 let normalized = pre_tokenize(text);
1001 split_words(&normalized)
1002 .iter()
1003 .map(|w| self.vocab.get(w).copied().unwrap_or(self.unk_id))
1004 .collect()
1005 }
1006
1007 fn decode(&self, ids: &[u32]) -> String {
1008 ids.iter()
1009 .filter_map(|&id| self.id_to_token.get(&id).cloned())
1010 .collect::<Vec<String>>()
1011 .join(" ")
1012 }
1013
1014 fn vocab_size(&self) -> usize {
1015 self.vocab.len()
1016 }
1017}
1018
#[derive(Debug, Clone)]
pub struct SimpleCharTokenizer {
    /// Character -> id; ids start at 1 when built via `build`.
    vocab: HashMap<char, u32>,
    /// id -> character (inverse of `vocab`).
    id_to_char: HashMap<u32, char>,
    /// Id returned for characters outside the vocabulary (0 in `build`).
    unk_id: u32,
}
1048
1049impl SimpleCharTokenizer {
1050 pub fn build(texts: &[&str]) -> Self {
1054 let mut char_set: Vec<char> = Vec::new();
1055 for text in texts {
1056 for ch in text.chars() {
1057 if !char_set.contains(&ch) {
1058 char_set.push(ch);
1059 }
1060 }
1061 }
1062 char_set.sort();
1063
1064 let mut vocab = HashMap::new();
1065 let mut id_to_char = HashMap::new();
1066 let unk_id = 0_u32;
1068
1069 for ch in char_set {
1070 let id = (vocab.len() + 1) as u32; vocab.insert(ch, id);
1072 id_to_char.insert(id, ch);
1073 }
1074
1075 Self {
1076 vocab,
1077 id_to_char,
1078 unk_id,
1079 }
1080 }
1081
1082 pub fn from_vocab(vocab: HashMap<char, u32>, unk_id: u32) -> Self {
1084 let id_to_char: HashMap<u32, char> = vocab.iter().map(|(&c, &id)| (id, c)).collect();
1085 Self {
1086 vocab,
1087 id_to_char,
1088 unk_id,
1089 }
1090 }
1091}
1092
1093impl TransformerTokenizer for SimpleCharTokenizer {
1094 fn encode(&self, text: &str) -> Vec<u32> {
1095 text.chars()
1096 .map(|ch| self.vocab.get(&ch).copied().unwrap_or(self.unk_id))
1097 .collect()
1098 }
1099
1100 fn decode(&self, ids: &[u32]) -> String {
1101 ids.iter()
1102 .filter_map(|&id| self.id_to_char.get(&id).copied())
1103 .collect()
1104 }
1105
1106 fn vocab_size(&self) -> usize {
1107 self.vocab.len() + 1 }
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;

    // ---- BPETokenizer ----

    #[test]
    fn test_bpe_train_basic() {
        let corpus = &[
            "the cat sat on the mat",
            "the dog sat on the log",
            "the cat and the dog",
        ];
        let tok = BPETokenizer::train(corpus, 80).expect("train should succeed");
        assert!(tok.vocab_size() > 0);
        assert!(tok.vocab_size() <= 80);
    }

    #[test]
    fn test_bpe_train_empty_corpus_error() {
        let result = BPETokenizer::train(&[], 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_bpe_train_small_vocab_error() {
        // vocab_size below the special-tokens + 1 minimum must fail.
        let corpus = &["hello"];
        let result = BPETokenizer::train(corpus, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_bpe_encode_decode_roundtrip() {
        let corpus = &["hello world", "hello there world", "the world is great"];
        let tok = BPETokenizer::train(corpus, 100).expect("train");

        let text = "hello world";
        let ids = tok.encode(text);
        assert!(!ids.is_empty());

        // BPE decode concatenates tokens without word boundaries, so
        // only substring containment is checked, not equality.
        let decoded = tok.decode(&ids);
        assert!(decoded.contains("hello"));
        assert!(decoded.contains("world"));
    }

    #[test]
    fn test_bpe_encode_empty() {
        let corpus = &["hello world"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");
        let ids = tok.encode("");
        assert!(ids.is_empty());
    }

    #[test]
    fn test_bpe_special_tokens() {
        // All five defaults are registered by the constructor.
        let tok = BPETokenizer::new();
        assert!(tok.special_token_id("[PAD]").is_some());
        assert!(tok.special_token_id("[UNK]").is_some());
        assert!(tok.special_token_id("[CLS]").is_some());
        assert!(tok.special_token_id("[SEP]").is_some());
        assert!(tok.special_token_id("[MASK]").is_some());
    }

    #[test]
    fn test_bpe_add_special_token() {
        let mut tok = BPETokenizer::new();
        let id = tok.add_special_token("[BOS]");
        assert_eq!(tok.special_token_id("[BOS]"), Some(id));
        assert_eq!(tok.vocab_size(), DEFAULT_SPECIAL_TOKENS.len() + 1);
    }

    #[test]
    fn test_bpe_save_load_json() {
        // NOTE(review): fixed temp filename could collide if multiple
        // test processes run this concurrently.
        let corpus = &["the cat sat on the mat", "the dog sat on the log"];
        let tok = BPETokenizer::train(corpus, 60).expect("train");

        let dir = std::env::temp_dir();
        let path = dir.join("test_bpe_tokenizer.json");

        tok.save_json(&path).expect("save");
        let loaded = BPETokenizer::load_json(&path).expect("load");

        assert_eq!(tok.vocab_size(), loaded.vocab_size());
        assert_eq!(tok.merges.len(), loaded.merges.len());

        // The reloaded tokenizer must encode identically.
        let text = "the cat sat";
        let ids1 = tok.encode(text);
        let ids2 = loaded.encode(text);
        assert_eq!(ids1, ids2);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_bpe_unknown_chars() {
        let corpus = &["abc"];
        let tok = BPETokenizer::train(corpus, 30).expect("train");

        // None of x/y/z were seen during training, so every id is UNK.
        let ids = tok.encode("xyz");
        let unk = tok.unk_id();
        assert!(ids.iter().all(|&id| id == unk));
    }

    #[test]
    fn test_bpe_default_constructor() {
        let tok = BPETokenizer::default();
        assert_eq!(tok.vocab_size(), DEFAULT_SPECIAL_TOKENS.len());
        assert!(tok.merges.is_empty());
    }

    #[test]
    fn test_bpe_vocab_size_trait() {
        // Trait-object dispatch must agree with the inherent method.
        let corpus = &["hello world hello"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");
        let trait_ref: &dyn TransformerTokenizer = &tok;
        assert!(trait_ref.vocab_size() > 0);
    }

    // ---- WordPieceTokenizer ----

    #[test]
    fn test_wordpiece_basic() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);
        vocab.insert("world".to_string(), 2);
        vocab.insert("hel".to_string(), 3);
        vocab.insert("##lo".to_string(), 4);
        vocab.insert("wor".to_string(), 5);
        vocab.insert("##ld".to_string(), 6);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("hello world");

        // Longest-match-first prefers whole words over subword pieces.
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_wordpiece_subword_split() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("un".to_string(), 1);
        vocab.insert("##aff".to_string(), 2);
        vocab.insert("##able".to_string(), 3);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("unaffable");
        assert_eq!(tokens, vec!["un", "##aff", "##able"]);
    }

    #[test]
    fn test_wordpiece_unknown_word() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("xyz");
        assert!(tokens.contains(&"[UNK]".to_string()));
    }

    #[test]
    fn test_wordpiece_encode_decode() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("play".to_string(), 1);
        vocab.insert("##ing".to_string(), 2);
        vocab.insert("##er".to_string(), 3);
        vocab.insert("##s".to_string(), 4);
        vocab.insert("the".to_string(), 5);

        let tok = WordPieceTokenizer::new(vocab);

        let ids = tok.encode("the playing");
        assert!(!ids.is_empty());

        let decoded = tok.decode(&ids);
        assert!(decoded.contains("the"));
        assert!(decoded.contains("play"));
        assert!(decoded.contains("ing"));
    }

    #[test]
    fn test_wordpiece_max_word_len() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("a".to_string(), 1);

        let tok = WordPieceTokenizer::new(vocab).with_max_word_len(5);

        // Words longer than the limit map straight to [UNK].
        let tokens = tok.tokenize("toolongword");
        assert_eq!(tokens, vec!["[UNK]"]);
    }

    #[test]
    fn test_wordpiece_custom_prefix() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hel".to_string(), 1);
        vocab.insert("@@lo".to_string(), 2);

        let tok = WordPieceTokenizer::new(vocab).with_continuing_prefix("@@");
        let tokens = tok.tokenize("hello");
        assert_eq!(tokens, vec!["hel", "@@lo"]);
    }

    #[test]
    fn test_wordpiece_from_vocab_file() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_wp_vocab.txt");

        // Scope the writer so the file is flushed/closed before reading.
        {
            let mut f = File::create(&path).expect("create vocab file");
            writeln!(f, "[UNK]").expect("write");
            writeln!(f, "[PAD]").expect("write");
            writeln!(f, "hello").expect("write");
            writeln!(f, "world").expect("write");
            writeln!(f, "##ing").expect("write");
        }

        let tok = WordPieceTokenizer::from_vocab_file(&path).expect("load vocab");
        assert_eq!(tok.vocab_size(), 5);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_wordpiece_vocab_size() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("a".to_string(), 1);
        vocab.insert("b".to_string(), 2);

        let tok = WordPieceTokenizer::new(vocab);
        assert_eq!(tok.vocab_size(), 3);
    }

    // ---- SimpleWhitespaceTokenizer ----

    #[test]
    fn test_whitespace_build_and_encode() {
        let texts = &["hello world", "hello there", "world peace"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        let ids = tok.encode("hello world");
        assert_eq!(ids.len(), 2);

        // Distinct words must get distinct ids.
        assert_ne!(ids[0], ids[1]);
    }

    #[test]
    fn test_whitespace_decode() {
        let texts = &["hello world", "foo bar"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        let ids = tok.encode("hello world");
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, "hello world");
    }

    #[test]
    fn test_whitespace_unknown_word() {
        let texts = &["hello world"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        // "xyz" was never seen, so it maps to the UNK id (0).
        let ids = tok.encode("hello xyz");
        assert_eq!(ids[1], 0);
    }

    #[test]
    fn test_whitespace_max_vocab_limit() {
        let texts = &["a b c d e f g"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 4);
        assert!(tok.vocab_size() <= 4);
    }

    #[test]
    fn test_whitespace_vocab_size() {
        // Three words + the implicit [UNK] entry.
        let texts = &["one two three"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);
        assert_eq!(tok.vocab_size(), 4);
    }

    // ---- SimpleCharTokenizer ----

    #[test]
    fn test_char_build_and_encode() {
        let texts = &["abc", "bcd"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("abc");
        assert_eq!(ids.len(), 3);
        // Known characters never map to the reserved unknown id 0.
        assert!(ids.iter().all(|&id| id > 0));
    }

    #[test]
    fn test_char_decode() {
        let texts = &["hello"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("hello");
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, "hello");
    }

    #[test]
    fn test_char_unknown_char() {
        let texts = &["abc"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("xyz");
        assert!(ids.iter().all(|&id| id == 0));
    }

    #[test]
    fn test_char_vocab_size() {
        // Three distinct chars (a, b, c) + the implicit unknown slot.
        let texts = &["ab", "bc"];
        let tok = SimpleCharTokenizer::build(texts);
        assert_eq!(tok.vocab_size(), 4);
    }

    #[test]
    fn test_char_roundtrip() {
        // Char tokenization is lossless when all chars were seen.
        let texts = &["The quick brown fox!"];
        let tok = SimpleCharTokenizer::build(texts);

        let original = "The quick brown fox!";
        let ids = tok.encode(original);
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, original);
    }

    // ---- Cross-tokenizer / helpers ----

    #[test]
    fn test_trait_object_dispatch() {
        // All four tokenizers must be usable behind &dyn TransformerTokenizer.
        let corpus = &["hello world hello"];
        let bpe = BPETokenizer::train(corpus, 50).expect("train");

        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);
        vocab.insert("world".to_string(), 2);
        let wp = WordPieceTokenizer::new(vocab);

        let ws_texts = &["hello world"];
        let ws = SimpleWhitespaceTokenizer::build(ws_texts, 50);

        let char_texts = &["hello world"];
        let ch = SimpleCharTokenizer::build(char_texts);

        let tokenizers: Vec<&dyn TransformerTokenizer> = vec![&bpe, &wp, &ws, &ch];
        for tok in tokenizers {
            assert!(tok.vocab_size() > 0);
            let ids = tok.encode("hello");
            assert!(!ids.is_empty());
            let _ = tok.decode(&ids);
        }
    }

    #[test]
    fn test_json_escape_roundtrip() {
        let original = "hello \"world\"\nnewline\\backslash\ttab";
        let escaped = escape_json_string(original);
        let unescaped = unescape_json_string(&escaped);
        assert_eq!(original, unescaped);
    }

    #[test]
    fn test_bpe_multiple_sentences() {
        let corpus = &[
            "machine learning is transforming the world",
            "deep learning models use transformers",
            "natural language processing with transformers",
            "the transformer architecture is powerful",
        ];
        let tok = BPETokenizer::train(corpus, 120).expect("train");

        let text = "learning transformers";
        let ids = tok.encode(text);
        assert!(!ids.is_empty());

        // Every character of the input appeared in training, so no id
        // may be the unknown id.
        let unk = tok.unk_id();
        assert!(ids.iter().all(|&id| id != unk));
    }

    #[test]
    fn test_bpe_merges_reduce_token_count() {
        let corpus = &["aaaa aaaa aaaa aaaa aaaa", "aaaa aaaa aaaa aaaa aaaa"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");

        // With merges learned, "aaaa" must encode to fewer than 4 ids.
        let ids = tok.encode("aaaa");
        assert!(
            ids.len() < 4,
            "BPE should merge repeated chars: got {} tokens",
            ids.len()
        );
    }

    #[test]
    fn test_wordpiece_empty_input() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        let tok = WordPieceTokenizer::new(vocab);

        let tokens = tok.tokenize("");
        assert!(tokens.is_empty());

        let ids = tok.encode("");
        assert!(ids.is_empty());

        let decoded = tok.decode(&[]);
        assert!(decoded.is_empty());
    }
}
1555}