1use crate::error::{Result, TextError};
9use std::collections::HashMap;
10use std::io::{BufRead, BufReader, Write};
11
/// Everything `ByteLevelBpeTokenizer::init_base` hands back, in this order:
/// the byte → char encoder, the char → byte decoder, the token → id vocabulary,
/// and the id → token table.
type InitBaseMaps = (
    HashMap<u8, char>,
    HashMap<char, u8>,
    HashMap<String, u32>,
    Vec<String>,
);
19
/// Returns the reversible byte → character table used for byte-level BPE (the GPT-2 scheme).
///
/// Bytes that are already visible Latin-1 characters (`!`..=`~`, `¡`..=`¬`, `®`..=`ÿ`) map to
/// themselves. The remaining 68 bytes (controls, space, DEL, C1 controls, NBSP, soft hyphen)
/// are assigned consecutive code points starting at U+0100, in ascending byte order. Space
/// therefore becomes `Ġ` (U+0120). The result has exactly 256 entries and is a bijection.
pub fn bytes_to_unicode() -> HashMap<u8, char> {
    let keeps_own_code_point = |b: u8| matches!(b, b'!'..=b'~' | 0xa1..=0xac | 0xae..=0xff);

    let mut table = HashMap::with_capacity(256);
    for byte in (0u8..=255).filter(|&b| keeps_own_code_point(b)) {
        table.insert(byte, byte as char);
    }

    // Remap the non-printable bytes into U+0100.. in ascending byte order.
    let mut next_offset = 0u32;
    for byte in (0u8..=255).filter(|&b| !keeps_own_code_point(b)) {
        table.insert(byte, char::from_u32(0x0100 + next_offset).unwrap_or('\u{0100}'));
        next_offset += 1;
    }
    table
}
50
/// Hyper-parameters for `ByteLevelBpeTokenizer::train`.
#[derive(Debug, Clone)]
pub struct ByteLevelBpeConfig {
    /// Target vocabulary size, counting the 256 byte-level base tokens. Training stops once the
    /// vocabulary reaches this size (or no eligible pair remains).
    pub vocab_size: usize,
    /// A pair must occur at least this many times (weighted by word frequency) to be merged.
    pub min_frequency: usize,
    /// When `true`, every word after the first in a training text is prefixed with the encoded
    /// space symbol, so word-initial tokens are learned separately from word-internal ones.
    pub add_prefix_space: bool,
}

impl Default for ByteLevelBpeConfig {
    /// A GPT-2-sized vocabulary (50 257 entries), pairs must be seen at least twice, and
    /// word-boundary prefixing is enabled.
    fn default() -> Self {
        Self {
            vocab_size: 50257,
            min_frequency: 2,
            add_prefix_space: true,
        }
    }
}
73
74#[derive(Debug, Clone)]
83pub struct ByteLevelBpeTokenizer {
84 pub vocab: HashMap<String, u32>,
86 pub id_to_token: Vec<String>,
88 pub merges: Vec<(String, String)>,
90 pub byte_encoder: HashMap<u8, char>,
92 pub byte_decoder: HashMap<char, u8>,
94}
95
96impl ByteLevelBpeTokenizer {
98 fn init_base() -> InitBaseMaps {
100 let byte_encoder = bytes_to_unicode();
101 let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
102
103 let mut vocab: HashMap<String, u32> = HashMap::new();
104 let mut id_to_token: Vec<String> = Vec::new();
105 for b in 0u8..=255u8 {
107 let ch = byte_encoder[&b];
108 let tok = ch.to_string();
109 if !vocab.contains_key(&tok) {
110 let id = id_to_token.len() as u32;
111 vocab.insert(tok.clone(), id);
112 id_to_token.push(tok);
113 }
114 }
115 (byte_encoder, byte_decoder, vocab, id_to_token)
116 }
117
118 fn apply_merges(&self, chars: Vec<String>) -> Vec<String> {
121 let mut word = chars;
122 let merge_rank: HashMap<(&str, &str), usize> = self
124 .merges
125 .iter()
126 .enumerate()
127 .map(|(i, (a, b))| (a.as_str(), b.as_str()))
128 .enumerate()
130 .map(|(i, _)| (("", ""), i)) .collect();
132 let merge_rank: HashMap<(String, String), usize> = self
134 .merges
135 .iter()
136 .enumerate()
137 .map(|(i, (a, b))| ((a.clone(), b.clone()), i))
138 .collect();
139
140 loop {
141 if word.len() < 2 {
142 break;
143 }
144 let mut best_rank = usize::MAX;
146 let mut best_idx = usize::MAX;
147 for i in 0..word.len() - 1 {
148 let pair = (word[i].clone(), word[i + 1].clone());
149 if let Some(&rank) = merge_rank.get(&pair) {
150 if rank < best_rank {
151 best_rank = rank;
152 best_idx = i;
153 }
154 }
155 }
156 if best_idx == usize::MAX {
157 break; }
159 let merged = format!("{}{}", word[best_idx], word[best_idx + 1]);
161 word.remove(best_idx + 1);
162 word[best_idx] = merged;
163 }
164 word
165 }
166
167 fn byte_encode_str(&self, s: &str) -> Vec<String> {
169 s.bytes()
170 .map(|b| {
171 self.byte_encoder
172 .get(&b)
173 .copied()
174 .unwrap_or('\u{FFFD}')
175 .to_string()
176 })
177 .collect()
178 }
179}
180
181impl ByteLevelBpeTokenizer {
184 pub fn train(texts: &[&str], config: ByteLevelBpeConfig) -> Self {
190 let (byte_encoder, byte_decoder, mut vocab, mut id_to_token) = Self::init_base();
191
192 let space_char = byte_encoder
199 .get(&0x20u8)
200 .copied()
201 .unwrap_or('\u{0120}')
202 .to_string();
203
204 let mut word_freq: HashMap<Vec<String>, usize> = HashMap::new();
205 for text in texts {
206 let mut first = true;
208 for word in text.split_whitespace() {
209 let mut encoded: Vec<String> = word
211 .bytes()
212 .map(|b| {
213 byte_encoder
214 .get(&b)
215 .copied()
216 .unwrap_or('\u{FFFD}')
217 .to_string()
218 })
219 .collect();
220 if !first && config.add_prefix_space {
222 encoded.insert(0, space_char.clone());
223 }
224 first = false;
225 *word_freq.entry(encoded).or_insert(0) += 1;
226 }
227 }
228
229 let mut merges: Vec<(String, String)> = Vec::new();
230
231 while vocab.len() < config.vocab_size {
233 let mut pair_freq: HashMap<(String, String), usize> = HashMap::new();
235 for (word, &freq) in &word_freq {
236 for i in 0..word.len().saturating_sub(1) {
237 let pair = (word[i].clone(), word[i + 1].clone());
238 *pair_freq.entry(pair).or_insert(0) += freq;
239 }
240 }
241
242 let best = pair_freq
244 .iter()
245 .filter(|(_, &f)| f >= config.min_frequency)
246 .max_by_key(|(_, &f)| f);
247
248 let (left, right) = match best {
249 Some(((l, r), _)) => (l.clone(), r.clone()),
250 None => break,
251 };
252
253 merges.push((left.clone(), right.clone()));
255 let merged = format!("{}{}", left, right);
256 let new_id = id_to_token.len() as u32;
257 vocab.insert(merged.clone(), new_id);
258 id_to_token.push(merged.clone());
259
260 let updated: HashMap<Vec<String>, usize> = word_freq
262 .into_iter()
263 .map(|(word, freq)| {
264 let new_word = merge_pair_in_word(word, &left, &right);
265 (new_word, freq)
266 })
267 .collect();
268 word_freq = updated;
269 }
270
271 ByteLevelBpeTokenizer {
272 vocab,
273 id_to_token,
274 merges,
275 byte_encoder,
276 byte_decoder,
277 }
278 }
279}
280
/// Replaces every non-overlapping, left-to-right occurrence of the adjacent symbols
/// `left`, `right` in `word` with their concatenation.
fn merge_pair_in_word(word: Vec<String>, left: &str, right: &str) -> Vec<String> {
    let mut merged_word = Vec::with_capacity(word.len());
    let mut symbols = word.into_iter().peekable();
    while let Some(symbol) = symbols.next() {
        let pair_starts_here =
            symbol == left && symbols.peek().map_or(false, |next| next == right);
        if pair_starts_here {
            // Consume the right half; it is folded into the merged symbol.
            symbols.next();
            merged_word.push(format!("{}{}", left, right));
        } else {
            merged_word.push(symbol);
        }
    }
    merged_word
}
296
297impl ByteLevelBpeTokenizer {
300 pub fn encode(&self, text: &str) -> Vec<u32> {
305 let mut ids = Vec::new();
306 let space_tok = self
312 .byte_encoder
313 .get(&0x20u8)
314 .copied()
315 .unwrap_or('\u{0120}')
316 .to_string();
317
318 let mut first = true;
319 for word in text.split_whitespace() {
320 let mut chars = self.byte_encode_str(word);
322 if !first {
324 chars.insert(0, space_tok.clone());
325 }
326 first = false;
327 let merged = self.apply_merges(chars);
328 for tok in merged {
329 if let Some(&id) = self.vocab.get(&tok) {
330 ids.push(id);
331 }
332 }
335 }
336 ids
337 }
338
339 pub fn decode(&self, ids: &[u32]) -> String {
343 let mut bytes: Vec<u8> = Vec::new();
345 for &id in ids {
346 if let Some(tok) = self.id_to_token.get(id as usize) {
347 for ch in tok.chars() {
348 if let Some(&b) = self.byte_decoder.get(&ch) {
349 bytes.push(b);
350 }
351 }
352 }
353 }
354 String::from_utf8_lossy(&bytes).into_owned()
355 }
356}
357
358impl ByteLevelBpeTokenizer {
361 pub fn save_vocab(&self, vocab_path: &str, merges_path: &str) -> Result<()> {
366 {
368 let mut f =
369 std::fs::File::create(vocab_path).map_err(|e| TextError::IoError(e.to_string()))?;
370 write!(f, "{{").map_err(|e| TextError::IoError(e.to_string()))?;
372 let mut pairs: Vec<(&String, &u32)> = self.vocab.iter().collect();
373 pairs.sort_by_key(|(_, &id)| id);
374 for (i, (tok, id)) in pairs.iter().enumerate() {
375 let escaped = escape_json_string(tok);
376 if i + 1 < pairs.len() {
377 write!(f, "\"{}\": {}, ", escaped, id)
378 .map_err(|e| TextError::IoError(e.to_string()))?;
379 } else {
380 write!(f, "\"{}\": {}", escaped, id)
381 .map_err(|e| TextError::IoError(e.to_string()))?;
382 }
383 }
384 writeln!(f, "}}").map_err(|e| TextError::IoError(e.to_string()))?;
385 }
386
387 {
389 let mut f = std::fs::File::create(merges_path)
390 .map_err(|e| TextError::IoError(e.to_string()))?;
391 writeln!(f, "#version: 0.2").map_err(|e| TextError::IoError(e.to_string()))?;
392 for (left, right) in &self.merges {
393 writeln!(f, "{} {}", left, right).map_err(|e| TextError::IoError(e.to_string()))?;
394 }
395 }
396 Ok(())
397 }
398
399 pub fn load(vocab_path: &str, merges_path: &str) -> Result<Self> {
401 let vocab_content =
403 std::fs::read_to_string(vocab_path).map_err(|e| TextError::IoError(e.to_string()))?;
404 let vocab = parse_vocab_json(&vocab_content)?;
405
406 let max_id = vocab.values().copied().max().unwrap_or(0) as usize;
408 let mut id_to_token = vec![String::new(); max_id + 1];
409 for (tok, &id) in &vocab {
410 if let Some(slot) = id_to_token.get_mut(id as usize) {
411 *slot = tok.clone();
412 }
413 }
414
415 let merges_file =
417 std::fs::File::open(merges_path).map_err(|e| TextError::IoError(e.to_string()))?;
418 let reader = BufReader::new(merges_file);
419 let mut merges = Vec::new();
420 for line in reader.lines() {
421 let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
422 let line = line.trim();
423 if line.is_empty() || line.starts_with('#') {
424 continue;
425 }
426 let parts: Vec<&str> = line.splitn(2, ' ').collect();
427 if parts.len() == 2 {
428 merges.push((parts[0].to_string(), parts[1].to_string()));
429 }
430 }
431
432 let byte_encoder = bytes_to_unicode();
433 let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
434
435 Ok(ByteLevelBpeTokenizer {
436 vocab,
437 id_to_token,
438 merges,
439 byte_encoder,
440 byte_decoder,
441 })
442 }
443
444 pub fn vocab_size(&self) -> usize {
446 self.vocab.len()
447 }
448
449 pub fn id_to_token(&self, id: u32) -> Option<&str> {
451 self.id_to_token.get(id as usize).map(|s| s.as_str())
452 }
453
454 pub fn token_to_id(&self, token: &str) -> Option<u32> {
456 self.vocab.get(token).copied()
457 }
458}
459
/// Escapes `s` for use inside a JSON string literal (without the surrounding quotes).
///
/// Quotes, backslashes, `\n`, `\r` and `\t` use their short escapes. Other control characters
/// below U+0020 become `\u00XX` with lowercase hex. Everything else is copied verbatim.
fn escape_json_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        let short_escape = match c {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(seq) = short_escape {
            escaped.push_str(seq);
        } else if (c as u32) < 0x20 {
            escaped.push_str(&format!("\\u{:04x}", c as u32));
        } else {
            escaped.push(c);
        }
    }
    escaped
}
480
481fn parse_vocab_json(s: &str) -> Result<HashMap<String, u32>> {
486 let s = s.trim();
487 let inner = s
488 .strip_prefix('{')
489 .and_then(|s| s.strip_suffix('}'))
490 .ok_or_else(|| TextError::IoError("Invalid vocab JSON: missing braces".to_string()))?;
491
492 let mut vocab = HashMap::new();
493 let chars: Vec<char> = inner.chars().collect();
495 let n = chars.len();
496 let mut i = 0;
497 let mut start = 0;
498
499 while i <= n {
500 let at_end = i == n;
501
502 if at_end {
503 let entry: String = chars[start..i].iter().collect();
505 let entry = entry.trim();
506 if !entry.is_empty() {
507 parse_vocab_entry(entry, &mut vocab)?;
508 }
509 break;
510 }
511
512 let ch = chars[i];
513
514 if ch == '"' {
515 i += 1;
517 while i < n {
518 let sc = chars[i];
519 i += 1;
520 if sc == '\\' {
521 i += 1;
523 } else if sc == '"' {
524 break;
525 }
526 }
527 continue;
529 }
530
531 if ch == ',' {
532 let entry: String = chars[start..i].iter().collect();
533 let entry = entry.trim();
534 if !entry.is_empty() {
535 parse_vocab_entry(entry, &mut vocab)?;
536 }
537 start = i + 1;
538 }
539
540 i += 1;
541 }
542
543 Ok(vocab)
544}
545
546fn parse_vocab_entry(entry: &str, vocab: &mut HashMap<String, u32>) -> Result<()> {
547 let colon_pos = find_colon_outside_string(entry)
549 .ok_or_else(|| TextError::IoError(format!("Invalid vocab entry (no colon): {}", entry)))?;
550 let key_part = entry[..colon_pos].trim();
551 let val_part = entry[colon_pos + 1..].trim();
552
553 let key = key_part
554 .strip_prefix('"')
555 .and_then(|s| s.strip_suffix('"'))
556 .map(unescape_json_string)
557 .ok_or_else(|| TextError::IoError(format!("Invalid vocab key: {}", key_part)))?;
558
559 let id: u32 = val_part
560 .parse()
561 .map_err(|_| TextError::IoError(format!("Invalid vocab id: {}", val_part)))?;
562
563 vocab.insert(key, id);
564 Ok(())
565}
566
/// Returns the byte offset of the first `:` that is not inside a double-quoted string.
///
/// Inside a string, a backslash escapes the following character, so `\"` does not end the
/// string. Outside strings, backslashes have no special meaning.
fn find_colon_outside_string(s: &str) -> Option<usize> {
    let mut inside_string = false;
    let mut skip_next = false;
    s.char_indices().find_map(|(offset, c)| {
        if skip_next {
            skip_next = false;
            return None;
        }
        match c {
            '\\' if inside_string => {
                skip_next = true;
                None
            }
            '"' => {
                inside_string = !inside_string;
                None
            }
            ':' if !inside_string => Some(offset),
            _ => None,
        }
    })
}
589
/// Decodes JSON string escapes in `s` (the contents of a string literal, without quotes).
///
/// Supports all RFC 8259 escapes: `\" \\ \/ \b \f \n \r \t` and `\uXXXX`. A UTF-16 surrogate
/// pair such as `\ud83d\ude00` is combined into one character. Malformed `\u` escapes and
/// unpaired surrogates are dropped, and an unknown escape `\x` yields `x`. Decoding is lenient
/// and never fails.
fn unescape_json_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }
        match chars.next() {
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            Some('b') => out.push('\u{0008}'),
            Some('f') => out.push('\u{000C}'),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some('u') => {
                if let Some(c) = decode_json_unicode_escape(&mut chars) {
                    out.push(c);
                }
            }
            Some(other) => out.push(other),
            None => {}
        }
    }
    out
}

/// Decodes the hex digits following a `\u`. For a high surrogate it consumes the next escape
/// only when that escape is a valid low surrogate.
fn decode_json_unicode_escape(chars: &mut std::str::Chars) -> Option<char> {
    let unit = read_json_hex4(chars)?;
    if !(0xD800..0xDC00).contains(&unit) {
        // BMP scalar value, or a lone low surrogate (rejected by `from_u32`).
        return char::from_u32(unit);
    }
    let mut ahead = chars.clone();
    if ahead.next() == Some('\\') && ahead.next() == Some('u') {
        if let Some(low) = read_json_hex4(&mut ahead) {
            if (0xDC00..0xE000).contains(&low) {
                *chars = ahead;
                return char::from_u32(0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00));
            }
        }
    }
    None
}

/// Reads up to four characters and parses them as a hexadecimal code unit.
fn read_json_hex4(chars: &mut std::str::Chars) -> Option<u32> {
    let hex: String = chars.by_ref().take(4).collect();
    u32::from_str_radix(&hex, 16).ok()
}
619
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bytes_to_unicode_count() {
        let map = bytes_to_unicode();
        assert_eq!(map.len(), 256, "should have exactly 256 entries");
    }

    #[test]
    fn test_bytes_to_unicode_bijective() {
        let map = bytes_to_unicode();
        let mut chars: Vec<char> = map.values().copied().collect();
        chars.sort();
        chars.dedup();
        assert_eq!(
            chars.len(),
            256,
            "all unicode chars must be distinct (bijection)"
        );
    }

    #[test]
    fn test_bytes_to_unicode_ascii_identity() {
        let map = bytes_to_unicode();
        // Printable ASCII keeps its own code point.
        for b in b'!'..=b'~' {
            let ch = map[&b];
            assert_eq!(
                ch as u32, b as u32,
                "byte {} should map to itself, got {}",
                b, ch as u32
            );
        }
    }

    #[test]
    fn test_train_vocab_size() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "hello world hello rust hello tokenizer",
            "byte level bpe tokenizer training test data for vocabulary",
            "more text data to train the byte level bpe model properly",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        assert!(
            tok.vocab_size() <= 300,
            "vocab size should not exceed requested"
        );
        assert!(
            tok.vocab_size() >= 256,
            "should have at least base 256 tokens"
        );
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let texts = [
            "hello world",
            "the quick brown fox",
            "rust programming language",
            "byte level encoding test",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 500,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        let input = "hello world";
        let ids = tok.encode(input);
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, input, "encode/decode roundtrip should be lossless");
    }

    #[test]
    fn test_gword_prefix() {
        // Non-first words are prefixed with the encoded space (U+0120, 'Ġ').
        let texts = ["hello world test"];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        let has_g_prefix = tok.vocab.keys().any(|k| k.starts_with('\u{0120}'));
        assert!(has_g_prefix, "vocabulary should contain Ġ-prefixed tokens");
    }

    #[test]
    fn test_hello_token() {
        let texts = ["hello world hello hello hello"];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: false,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        assert!(
            tok.vocab.contains_key("hello"),
            "hello should be in vocabulary after training on repeated hello"
        );
    }

    #[test]
    fn test_save_load_roundtrip() {
        let texts = [
            "hello world",
            "test tokenizer save load",
            "byte level bpe tokenizer",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 350,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);

        // Per-process file names so concurrent test runs sharing one temp dir cannot clobber
        // each other's files.
        let dir = std::env::temp_dir();
        let pid = std::process::id();
        let vocab_path = dir
            .join(format!("test_bpe_vocab_{}.json", pid))
            .to_string_lossy()
            .into_owned();
        let merges_path = dir
            .join(format!("test_bpe_merges_{}.txt", pid))
            .to_string_lossy()
            .into_owned();

        tok.save_vocab(&vocab_path, &merges_path)
            .expect("save failed");
        let loaded = ByteLevelBpeTokenizer::load(&vocab_path, &merges_path);
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_file(&vocab_path);
        let _ = std::fs::remove_file(&merges_path);
        let loaded = loaded.expect("load failed");

        assert_eq!(tok.vocab_size(), loaded.vocab_size());
        assert_eq!(tok.vocab, loaded.vocab);
        assert_eq!(tok.merges, loaded.merges);
        assert_eq!(tok.encode("hello world"), loaded.encode("hello world"));
    }
}