1use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word};
2use crate::tokenizer::{Model, Result, Token};
3use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH};
4use crate::utils::iter::ResultShunt;
5use serde_json::Value;
6use std::borrow::Cow;
7use std::{
8 collections::HashMap,
9 fs::File,
10 io::prelude::*,
11 io::{BufRead, BufReader},
12 path::{Path, PathBuf},
13};
14
/// Maps a token string to its id.
pub type Vocab = HashMap<String, u32>;
/// Reverse mapping: token id back to its string.
type VocabR = HashMap<u32, String>;
/// Maps a pair of token ids to `(merge rank, id of the merged token)`.
pub type MergeMap = HashMap<Pair, (u32, u32)>;
/// Ordered list of `(left, right)` merge pairs, as read from a merges file.
pub type Merges = Vec<(String, String)>;
19
/// Configuration accumulated by [`BpeBuilder`] before constructing a `BPE`.
struct Config {
    /// Optional `(vocab, merges)` file paths; when set, `build()` reads them
    /// and overwrites `vocab` / `merges` below.
    files: Option<(String, String)>,
    vocab: Vocab,
    merges: Merges,
    /// Capacity of the word-merge cache; 0 disables caching entirely.
    cache_capacity: usize,
    /// BPE-dropout probability; `build()` rejects values outside `[0.0, 1.0]`.
    dropout: Option<f32>,
    unk_token: Option<String>,
    continuing_subword_prefix: Option<String>,
    end_of_word_suffix: Option<String>,
    fuse_unk: bool,
    byte_fallback: bool,
    ignore_merges: bool,
}
33
/// Builder used to construct a `BPE` model, collecting options into a [`Config`].
pub struct BpeBuilder {
    config: Config,
}
38
39impl Default for BpeBuilder {
40 fn default() -> Self {
41 Self {
42 config: Config {
43 files: None,
44 vocab: HashMap::new(),
45 merges: vec![],
46 cache_capacity: DEFAULT_CACHE_CAPACITY,
47 dropout: None,
48 unk_token: None,
49 continuing_subword_prefix: None,
50 end_of_word_suffix: None,
51 fuse_unk: false,
52 byte_fallback: false,
53 ignore_merges: false,
54 },
55 }
56 }
57}
58
impl BpeBuilder {
    /// Constructs a new `BpeBuilder` with default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the vocab and merges file paths to load from during `build()`.
    #[must_use]
    pub fn files(mut self, vocab: String, merges: String) -> Self {
        self.config.files = Some((vocab, merges));
        self
    }

    /// Set the vocab (token -> id) and merges mappings directly.
    #[must_use]
    pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
        self.config.vocab = vocab;
        self.config.merges = merges;
        self
    }

    /// Set the cache capacity. A capacity of 0 disables the cache.
    #[must_use]
    pub fn cache_capacity(mut self, capacity: usize) -> Self {
        self.config.cache_capacity = capacity;
        self
    }

    /// Set the BPE-dropout probability (must be within [0.0, 1.0]).
    #[must_use]
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.config.dropout = Some(dropout);
        self
    }

    /// Set the unknown token to emit for out-of-vocabulary content.
    #[must_use]
    pub fn unk_token(mut self, unk_token: String) -> Self {
        self.config.unk_token = Some(unk_token);
        self
    }

    /// Set the prefix carried by subwords that do not start a word (e.g. "##").
    #[must_use]
    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
        self.config.continuing_subword_prefix = Some(prefix);
        self
    }

    /// Set the suffix appended to the last subword of a word.
    #[must_use]
    pub fn end_of_word_suffix(mut self, prefix: String) -> Self {
        self.config.end_of_word_suffix = Some(prefix);
        self
    }

    /// Set whether consecutive unknown tokens should be fused into one.
    #[must_use]
    pub fn fuse_unk(mut self, fuse_unk: bool) -> Self {
        self.config.fuse_unk = fuse_unk;
        self
    }

    /// Set whether unknown content falls back to per-byte `<0xXX>` tokens.
    #[must_use]
    pub fn byte_fallback(mut self, byte_fallback: bool) -> Self {
        self.config.byte_fallback = byte_fallback;
        self
    }
    /// Set whether sequences that are exact vocab entries skip the merge loop.
    #[must_use]
    pub fn ignore_merges(mut self, ignore_merges: bool) -> Self {
        self.config.ignore_merges = ignore_merges;
        self
    }

    /// Validate the configuration and construct the `BPE` model.
    ///
    /// # Errors
    /// - `Error::InvalidDropout` if dropout is outside [0.0, 1.0].
    /// - `Error::MergeTokenOutOfVocabulary` if a merge references (or produces)
    ///   a token missing from the vocab.
    /// - Any I/O / parse error when reading the configured files.
    pub fn build(mut self) -> Result<BPE> {
        // Validate dropout before doing any other work.
        if let Some(p) = self.config.dropout {
            if !(0.0..=1.0).contains(&p) {
                return Err(Error::InvalidDropout.into());
            }
        }

        // File paths take precedence over any directly-set vocab/merges.
        if let Some((vocab, merges)) = self.config.files {
            let (v, m) = BPE::read_file(&vocab, &merges)?;
            self.config.vocab = v;
            self.config.merges = m;
        }

        let vocab_r = self
            .config
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();
        // A capacity of 0 means "no cache at all".
        let cache = match self.config.cache_capacity {
            0 => None,
            capacity => Some(Cache::new(capacity)),
        };

        let vocab = self.config.vocab;
        // The right-hand side of a merge carries the continuing-subword prefix;
        // it is stripped before concatenation to recover the merged spelling.
        let prefix_len = if let Some(prefix) = &self.config.continuing_subword_prefix {
            prefix.len()
        } else {
            0
        };
        let merge_map: MergeMap = self
            .config
            .merges
            .into_iter()
            .enumerate()
            .map(|(i, (a, b))| -> Result<(Pair, (u32, u32))> {
                let a_id = vocab
                    .get(&a)
                    .ok_or_else(|| Error::MergeTokenOutOfVocabulary(a.to_owned()))?;
                let b_id = vocab
                    .get(&b)
                    .ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_owned()))?;
                let new_token = format!("{}{}", a, &b[prefix_len..]);
                let new_id = vocab
                    .get(&new_token)
                    .ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
                // Key: the pair of ids; value: (merge rank, merged token id).
                Ok(((*a_id, *b_id), (i as u32, *new_id)))
            })
            .collect::<Result<MergeMap>>()?;

        Ok(BPE {
            vocab,
            vocab_r,
            merges: merge_map,
            cache,
            dropout: self.config.dropout,
            unk_token: self.config.unk_token,
            continuing_subword_prefix: self.config.continuing_subword_prefix,
            end_of_word_suffix: self.config.end_of_word_suffix,
            fuse_unk: self.config.fuse_unk,
            byte_fallback: self.config.byte_fallback,
            ignore_merges: self.config.ignore_merges,
        })
    }
}
205
/// A Byte Pair Encoding model.
#[derive(PartialEq)]
pub struct BPE {
    /// The vocabulary: maps each token string to its id.
    pub(crate) vocab: Vocab,
    /// Reversed vocabulary, used to map ids back to token strings.
    pub(crate) vocab_r: VocabR,
    /// Maps a pair of ids to `(merge rank, merged token id)`.
    pub(crate) merges: MergeMap,
    /// Cache of already-merged words; `None` when caching is disabled.
    cache: Option<Cache<String, Word>>,
    /// BPE-dropout probability; at 1.0 no merge is applied at all.
    pub dropout: Option<f32>,
    /// Token emitted for out-of-vocabulary content, when set.
    pub unk_token: Option<String>,
    /// Prefix carried by subwords that do not start a word (e.g. "##").
    pub continuing_subword_prefix: Option<String>,
    /// Suffix appended to the final subword of a word.
    pub end_of_word_suffix: Option<String>,
    /// Whether consecutive unknown tokens are fused into a single one.
    pub fuse_unk: bool,
    /// Whether unknown content falls back to per-byte `<0xXX>` tokens
    /// (requires those byte tokens to be present in the vocab).
    pub byte_fallback: bool,
    /// Whether a sequence that is itself a vocab entry bypasses the merges.
    pub ignore_merges: bool,
}
234
235impl std::fmt::Debug for BPE {
236 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
237 fmt.debug_struct("BPE")
238 .field("dropout", &self.dropout)
239 .field("unk_token", &self.unk_token)
240 .field("continuing_subword_prefix", &self.continuing_subword_prefix)
241 .field("end_of_word_suffix", &self.end_of_word_suffix)
242 .field("fuse_unk", &self.fuse_unk)
243 .field("byte_fallback", &self.byte_fallback)
244 .field("vocab", &self.vocab.len())
245 .field("merges", &self.merges.len())
246 .field("ignore_merges", &self.ignore_merges)
247 .finish()
248 }
249}
250
251impl Default for BPE {
252 fn default() -> Self {
253 Self::builder().build().unwrap()
254 }
255}
256
impl Clone for BPE {
    // `Clone` is implemented by hand: the cache is not copied, each clone
    // gets `cache.fresh()` instead (presumably an empty cache with the same
    // capacity — see `utils::cache`; TODO confirm), so clones do not share
    // or duplicate cached entries.
    fn clone(&self) -> Self {
        let fresh_cache = self.cache.as_ref().map(|cache| cache.fresh());
        Self {
            vocab: self.vocab.clone(),
            vocab_r: self.vocab_r.clone(),
            merges: self.merges.clone(),
            cache: fresh_cache,
            dropout: self.dropout,
            unk_token: self.unk_token.clone(),
            continuing_subword_prefix: self.continuing_subword_prefix.clone(),
            end_of_word_suffix: self.end_of_word_suffix.clone(),
            fuse_unk: self.fuse_unk,
            byte_fallback: self.byte_fallback,
            ignore_merges: self.ignore_merges,
        }
    }
}
277
278pub(crate) fn convert_merges_to_hashmap<I: Iterator<Item = String>>(
281 iter: I,
282 _vocab: &Vocab,
283) -> Result<Merges> {
284 let mut merges = vec![];
285
286 let lines = iter.filter(|l| !l.starts_with("#version"));
287 for (rank, line) in lines.enumerate() {
288 let parts = line.split(' ').collect::<Vec<_>>();
289 if parts.len() != 2 {
290 return Err(Error::BadMerges(rank + 1).into());
291 }
292
293 merges.push((parts[0].to_string(), parts[1].to_string()));
294 }
295
296 Ok(merges)
297}
298
299impl BPE {
300 pub fn builder() -> BpeBuilder {
302 BpeBuilder::new()
303 }
304
305 pub fn new(vocab: Vocab, merges: Merges) -> Self {
307 Self::builder()
308 .vocab_and_merges(vocab, merges)
309 .build()
310 .unwrap()
311 }
312
313 pub fn from_file(vocab: &str, merges: &str) -> BpeBuilder {
315 Self::builder().files(vocab.to_owned(), merges.to_owned())
316 }
317
318 pub fn read_file(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> {
320 let vocab_file = File::open(vocab)?;
322 let mut vocab_file = BufReader::new(vocab_file);
323
324 let mut buffer = String::new();
325 vocab_file.read_to_string(&mut buffer)?;
326 let json: Value = serde_json::from_str(&buffer)?;
327 let mut vocab = HashMap::new();
328 match json {
329 Value::Object(m) => {
330 for (token, id) in m {
331 if let Value::Number(id) = id {
332 let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
333 vocab.insert(token, id);
334 }
335 }
336 }
337 _ => return Err(Box::new(Error::BadVocabulary)),
338 };
339
340 let merge_file = File::open(merges)?;
342 let merge_file = BufReader::new(merge_file);
343 let merges = ResultShunt::process(merge_file.lines(), |iter| {
344 convert_merges_to_hashmap(iter, &vocab)
345 })??;
346
347 Ok((vocab, merges))
348 }
349
350 pub fn clear_cache(&self) {
352 if let Some(ref cache) = self.cache {
353 cache.clear()
354 }
355 }
356
357 pub fn resize_cache(&mut self, capacity: usize) {
359 if let Some(ref mut cache) = self.cache {
360 cache.resize(capacity);
361 }
362 }
363
364 pub fn get_vocab(&self) -> Vocab {
365 self.vocab.clone()
366 }
367
368 pub fn get_unk_token(&self) -> &Option<String> {
369 &self.unk_token
370 }
371
372 pub fn get_continuing_subword_prefix(&self) -> &Option<String> {
373 &self.continuing_subword_prefix
374 }
375
376 fn merge_word(&self, w: &str) -> Result<Word> {
377 let mut indices = w.char_indices().map(|(idx, _)| idx).peekable();
378 let mut word = Word::with_capacity(w.len());
379 let mut unk: Option<(u32, usize)> = None;
380 while let Some(i) = indices.next() {
381 let end = indices.peek();
382 let is_first = i == 0;
383 let is_last = end.is_none();
384
385 let mut s = if let Some(e) = end {
386 Cow::Borrowed(&w[i..*e])
387 } else {
388 Cow::Borrowed(&w[i..])
389 };
390 let byte_len = s.len();
391
392 if !is_first {
394 if let Some(ref prefix) = self.continuing_subword_prefix {
395 s = format!("{prefix}{s}").into()
396 }
397 }
398 if is_last {
400 if let Some(ref suffix) = self.end_of_word_suffix {
401 s = format!("{s}{suffix}").into()
402 }
403 }
404
405 if let Some(id) = self.vocab.get(s.as_ref()) {
406 if let Some((unk_id, unk_len)) = unk {
407 word.add(unk_id, unk_len);
408 unk = None;
409 }
410 word.add(*id, byte_len);
411 } else {
412 if self.byte_fallback {
413 let tokens: Option<Vec<_>> = s
414 .bytes()
415 .map(|b| -> Option<&u32> {
416 let code = format!("<{b:#04X}>");
417
418 self.vocab.get(&code)
419 })
420 .collect();
421 if let Some(tokens) = tokens {
422 for t in tokens {
423 word.add(*t, 1);
424 }
425 continue;
426 }
427 }
428 if let Some(unk_token) = &self.unk_token {
429 unk = match (unk, self.fuse_unk) {
430 (Some((unk_id, unk_len)), true) => {
431 Some((unk_id, unk_len + byte_len))
433 }
434 (Some((unk_id, unk_len)), false) => {
435 word.add(unk_id, unk_len);
437 Some((
438 *self.vocab.get(unk_token).ok_or_else(|| {
439 Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
440 })?,
441 byte_len,
442 ))
443 }
444 _ => Some((
445 *self.vocab.get(unk_token).ok_or_else(|| {
446 Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
447 })?,
448 byte_len,
449 )),
450 };
451 }
452 }
453 }
454 if let Some((unk_id, unk_len)) = unk {
455 word.add(unk_id, unk_len);
456 }
457
458 word.merge_all(&self.merges, self.dropout);
459
460 Ok(word)
461 }
462
463 fn word_to_tokens<'a, 'b: 'a>(&'a self, word: &'b Word) -> impl Iterator<Item = Token> + 'a {
464 word.get_chars_iter()
465 .zip(word.get_offsets_iter())
466 .map(move |(id, offsets)| Token::new(id, self.vocab_r[&id].clone(), offsets))
467 }
468
469 fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
470 if self.ignore_merges {
471 if let Some(id) = self.vocab.get(sequence) {
472 return Ok(vec![Token::new(
473 *id,
474 sequence.to_string().clone(),
475 (0, sequence.len()),
476 )]);
477 }
478 }
479 if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
480 return Ok(self.word_to_tokens(hit).collect());
481 }
482 let word = self.merge_word(sequence)?;
483 let ret = self.word_to_tokens(&word).collect();
484 if let Some(ref cache) = self.cache {
485 if sequence.len() < MAX_LENGTH {
486 cache.set(sequence.to_owned(), word);
487 }
488 }
489 Ok(ret)
490 }
491}
492
impl Model for BPE {
    type Trainer = BpeTrainer;

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Tokenize a sequence. The cache is only consulted when dropout is
    /// inactive (`None` or `Some(0.0)`); with dropout active the
    /// segmentation can vary between calls (see the dropout test), so each
    /// call recomputes the merges.
    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>> {
        if sequence.is_empty() {
            return Ok(vec![]);
        }

        if self.dropout.is_none() || self.dropout == Some(0.0) {
            self.tokenize_with_cache(sequence)
        } else {
            let word = self.merge_word(sequence)?;
            Ok(self.word_to_tokens(&word).collect())
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }

    /// Write `<name>-vocab.json` (token -> id) and `<name>-merges.txt`
    /// (one pair per line, sorted by merge rank, "#version: 0.2" header)
    /// into `folder`, returning both paths.
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        // Vocab file.
        let vocab_file_name = match name {
            Some(name) => format!("{name}-vocab.json"),
            None => "vocab.json".to_string(),
        };

        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
            .iter()
            .collect();
        let mut vocab_file = File::create(&vocab_path)?;
        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
        let serialized = serde_json::to_string(&order_vocab_iter)?;
        vocab_file.write_all(serialized.as_bytes())?;

        // Merges file.
        let merges_file_name = match name {
            Some(name) => format!("{name}-merges.txt"),
            None => "merges.txt".to_string(),
        };

        let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())]
            .iter()
            .collect();
        let mut merges_file = File::create(&merges_path)?;
        // Sort merges by rank so the file reloads in the original order.
        let mut merges: Vec<(&Pair, &u32)> = self
            .merges
            .iter()
            .map(|(pair, (rank, _))| (pair, rank))
            .collect();
        merges.sort_unstable_by_key(|k| *k.1);
        merges_file.write_all(b"#version: 0.2\n")?;
        merges_file.write_all(
            &merges
                .into_iter()
                .flat_map(|(pair, _)| {
                    format!("{} {}\n", self.vocab_r[&pair.0], self.vocab_r[&pair.1]).into_bytes()
                })
                .collect::<Vec<_>>()[..],
        )?;

        Ok(vec![vocab_path, merges_path])
    }

    fn get_trainer(&self) -> BpeTrainer {
        BpeTrainer::default()
    }
}
573
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    // Serializing the reverse vocab must emit entries in id order.
    #[test]
    fn test_ordered_vocab_iter() {
        let vocab_r: VocabR = [
            (0, "a".into()),
            (1, "b".into()),
            (2, "c".into()),
            (3, "ab".into()),
        ]
        .iter()
        .cloned()
        .collect();
        let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
        let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
        assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
    }

    // With fuse_unk off (the default), each unknown char yields its own
    // <unk> token.
    #[test]
    fn test_unk_not_fused() {
        let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("cc").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(0u32, "<unk>".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 2)),
            ]
        );

        let tokens = bpe.tokenize("accb").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(1u32, "a".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 2)),
                Token::new(0u32, "<unk>".into(), (2, 3)),
                Token::new(2u32, "b".into(), (3, 4)),
            ]
        );
    }
    // With fuse_unk on, a run of unknown chars collapses into a single
    // <unk> token spanning the whole run.
    #[test]
    fn test_unk_get_fused() {
        let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .fuse_unk(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("cc").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 2)),]);

        let tokens = bpe.tokenize("accb").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(1u32, "a".into(), (0, 1)),
                Token::new(0u32, "<unk>".into(), (1, 3)),
                Token::new(2u32, "b".into(), (3, 4)),
            ]
        );
    }

    // dropout=None and dropout=0.0 apply every merge; dropout=1.0 applies
    // none; 0<dropout<1 yields a nondeterministic but valid segmentation.
    #[test]
    fn test_tokenize_with_and_without_dropout() {
        let vocab: Vocab = [
            ("u".into(), 0),
            ("n".into(), 1),
            ("r".into(), 2),
            ("e".into(), 3),
            ("l".into(), 4),
            ("a".into(), 5),
            ("t".into(), 6),
            ("d".into(), 7),
            ("re".into(), 8),
            ("at".into(), 9),
            ("ed".into(), 10),
            ("un".into(), 11),
            ("ated".into(), 12),
            ("rel".into(), 13),
            ("related".into(), 14),
            ("unrelated".into(), 15),
        ]
        .iter()
        .cloned()
        .collect();
        let merges: Merges = vec![
            ("r".to_string(), "e".to_string()),
            ("a".to_string(), "t".to_string()),
            ("e".to_string(), "d".to_string()),
            ("u".to_string(), "n".to_string()),
            ("at".to_string(), "ed".to_string()),
            ("re".to_string(), "l".to_string()),
            ("rel".to_string(), "ated".to_string()),
            ("un".to_string(), "related".to_string()),
        ];
        let mut bpe = BPE::new(vocab, merges);

        // No dropout: everything merges into the single full-word token.
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

        // dropout = 0.0 behaves exactly like no dropout.
        bpe.dropout = Some(0.0);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

        // dropout = 1.0: every merge is dropped, single chars remain.
        bpe.dropout = Some(1.0);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(0u32, "u".into(), (0, 1)),
                Token::new(1u32, "n".into(), (1, 2)),
                Token::new(2u32, "r".into(), (2, 3)),
                Token::new(3u32, "e".into(), (3, 4)),
                Token::new(4u32, "l".into(), (4, 5)),
                Token::new(5u32, "a".into(), (5, 6)),
                Token::new(6u32, "t".into(), (6, 7)),
                Token::new(3u32, "e".into(), (7, 8)),
                Token::new(7u32, "d".into(), (8, 9)),
            ]
        );

        // Intermediate dropout: result varies, but is always 1..=9 tokens.
        bpe.dropout = Some(0.5);
        let tokens = bpe.tokenize("unrelated").unwrap();
        assert!(!tokens.is_empty() && tokens.len() <= 9);
    }

    // Round-trip: vocab/merges written to files load back correctly.
    #[test]
    fn test_bpe_from_file() {
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all(b"{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}")
            .unwrap();

        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b").unwrap();

        let builder = BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        );
        let bpe = builder.build().unwrap();

        // Merge pair (a=0, b=1) has rank 0 and produces token "ab" (id 3).
        assert_eq!(bpe.merges.get(&(0, 1)).unwrap(), &(0u32, 3u32));

        assert_eq!(bpe.vocab.get("a").unwrap(), &0u32);
        assert_eq!(bpe.vocab.get("b").unwrap(), &1u32);
        assert_eq!(bpe.vocab.get("c").unwrap(), &2u32);
        assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
    }

    // dropout = 0.0 is a valid configuration (only out-of-range values fail).
    #[test]
    fn test_bpe_with_dropout_0() {
        let bpe = BPE::builder().dropout(0.0).build().unwrap();
        assert_eq!(bpe.dropout, Some(0.0));
    }

    // Merges whose right side carries the "##" prefix must still resolve.
    #[test]
    fn test_bpe_with_continuing_subword_prefix() {
        let vocab: Vocab = vec![
            ("a".to_string(), 0),
            ("##b".to_string(), 1),
            ("##c".to_string(), 2),
            ("ab".to_string(), 3),
            ("abc".to_string(), 4),
        ]
        .into_iter()
        .collect();

        let merges = vec![
            ("a".to_string(), "##b".to_string()),
            ("ab".to_string(), "##c".to_string()),
        ];

        let bpe = BPE::builder()
            .vocab_and_merges(vocab, merges)
            .unk_token("[UNK]".to_string())
            .continuing_subword_prefix("##".to_string())
            .build()
            .unwrap();

        let res = bpe.tokenize("ab");
        assert_eq!(
            res.unwrap(),
            vec![Token {
                id: 3,
                value: "ab".to_string(),
                offsets: (0, 2)
            }]
        );
        let res = bpe.tokenize("abc");
        assert_eq!(
            res.unwrap(),
            vec![Token {
                id: 4,
                value: "abc".to_string(),
                offsets: (0, 3)
            }]
        );
    }

    // A merge referencing a token missing from the vocab must fail to build.
    #[test]
    fn test_bpe_from_file_merge_token_oov() {
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all(b"{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}")
            .unwrap();

        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b\na d").unwrap();

        match BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        )
        .build()
        {
            Ok(_) => unreachable!(),
            Err(err) => match err.downcast_ref::<Error>() {
                Some(Error::MergeTokenOutOfVocabulary(token)) => {
                    assert_eq!(*token, String::from("d"))
                }
                _ => unreachable!(),
            },
        }
    }

    // A malformed merge line reports its 1-based index (after the header).
    #[test]
    fn test_bpe_from_file_bad_merges() {
        let mut vocab_file = NamedTempFile::new().unwrap();
        vocab_file
            .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
            .unwrap();

        let mut merges_file = NamedTempFile::new().unwrap();
        merges_file.write_all(b"#version: 0.2\na b\nc").unwrap();

        match BPE::from_file(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap(),
        )
        .build()
        {
            Ok(_) => unreachable!(),
            Err(err) => match err.downcast_ref::<Error>() {
                Some(Error::BadMerges(line)) => assert_eq!(*line, 2),
                _ => unreachable!(),
            },
        }
    }

    // 'a' is 0x61: with byte fallback its byte token is preferred over <unk>.
    #[test]
    fn test_bpe_byte_fallback() {
        let vocab: Vocab = [("<unk>".into(), 0), ("<0x61>".into(), 1)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .byte_fallback(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);

        let tokens = bpe.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "<0x61>".into(), (0, 1)),]);
    }

    // Byte fallback also covers control characters such as '\n' (0x0A).
    #[test]
    fn test_bpe_byte_fallback_newline() {
        let vocab: Vocab = [("<unk>".into(), 0), ("<0x0A>".into(), 1)]
            .iter()
            .cloned()
            .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .byte_fallback(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize("\n").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
    }

    // With ignore_merges on, an exact vocab match bypasses the merge loop;
    // turning it off yields the normal merge-driven segmentation.
    #[test]
    fn test_ignore_merges() {
        let vocab: Vocab = [
            (".:.:".into(), 0),
            ("Ġbelirtilen".into(), 1),
            (".".into(), 2),
            (":".into(), 3),
            ("bel".into(), 4),
            ("irtilen".into(), 5),
            ("Ġ".into(), 6),
            (".:".into(), 7),
            ("belirtilen".into(), 8),
            (".:.".into(), 9),
            ("be".into(), 10),
            ("l".into(), 11),
            ("ir".into(), 12),
            ("ti".into(), 13),
            ("en".into(), 14),
            ("irtil".into(), 15),
            ("irti".into(), 16),
            ("i".into(), 17),
            ("r".into(), 18),
            ("t".into(), 19),
            ("b".into(), 20),
            ("e".into(), 21),
            ("n".into(), 22),
        ]
        .iter()
        .cloned()
        .collect();
        let mut bpe = BpeBuilder::default()
            .vocab_and_merges(
                vocab,
                vec![
                    (".".into(), ":".into()),
                    ("b".into(), "e".into()),
                    ("be".into(), "l".into()),
                    ("i".into(), "r".into()),
                    ("t".into(), "i".into()),
                    ("ir".into(), "ti".into()),
                    ("e".into(), "n".into()),
                    ("irti".into(), "l".into()),
                ],
            )
            .ignore_merges(true)
            .build()
            .unwrap();
        let tokens = bpe.tokenize(".:.:").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 4))]);

        // Offsets are byte-based: "Ġ" is 2 bytes, so the span is (0, 12).
        let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
        assert_eq!(
            tokens,
            vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 12))]
        );

        bpe.ignore_merges = false;

        let tokens = bpe.tokenize(".:.:").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::new(7u32, ".:".into(), (0, 2)),
                Token::new(7u32, ".:".into(), (2, 4))
            ]
        );

        let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token {
                    id: 6,
                    value: "Ġ".into(),
                    offsets: (0, 2)
                },
                Token {
                    id: 4,
                    value: "bel".into(),
                    offsets: (2, 5)
                },
                Token {
                    id: 15,
                    value: "irtil".into(),
                    offsets: (5, 10)
                },
                Token {
                    id: 14,
                    value: "en".into(),
                    offsets: (10, 12)
                }
            ]
        )
    }
}