// rust_transformers/preprocessing/tokenizer/openai_gpt_tokenizer.rs

use crate::OpenAiGptVocab;
use crate::preprocessing::vocab::base_vocab::Vocab;
use crate::preprocessing::tokenizer::base_tokenizer::{Tokenizer, BaseTokenizer};
use std::collections::HashMap;
use crate::preprocessing::tokenizer::tokenization_utils::{split_on_special_tokens, openai_gpt_bpe};
use std::rc::Rc;
use std::cell::RefCell;
use crate::preprocessing::vocab::bpe_vocab::BpePairVocab;
use std::sync::Arc;

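/// Tokenizer for the OpenAI GPT model. Splits the input on registered special tokens,
/// applies the base tokenizer to the remaining segments, then merges characters into
/// sub-words with byte-pair encoding (BPE). Per-word BPE outputs are memoized in `cache`.
///
/// A minimal usage sketch (the file paths below are placeholders):
///
/// ```ignore
/// let tokenizer = OpenAiGptTokenizer::from_file("path/to/vocab.json", "path/to/merges.txt");
/// let tokens = tokenizer.tokenize("the earth");
/// ```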
pub struct OpenAiGptTokenizer {
    vocab: Arc<OpenAiGptVocab>,
    base_tokenizer: BaseTokenizer<OpenAiGptVocab>,
    bpe_ranks: Rc<BpePairVocab>,
    cache: RefCell<HashMap<String, Vec<String>>>,
}

impl OpenAiGptTokenizer {
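    /// Builds an `OpenAiGptTokenizer` from files on disk: `vocab_path` points to the
    /// token-to-id vocabulary, `merges_path` to the ranked BPE merge pairs.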
    pub fn from_file(vocab_path: &str, merges_path: &str) -> OpenAiGptTokenizer {
        let vocab = Arc::new(OpenAiGptVocab::from_file(vocab_path));
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        let bpe_ranks = Rc::new(BpePairVocab::from_file(merges_path));
        let cache = RefCell::new(HashMap::new());
        OpenAiGptTokenizer { vocab, base_tokenizer, bpe_ranks, cache }
    }

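    /// Builds an `OpenAiGptTokenizer` from an already-loaded vocabulary and merges
    /// vocabulary, avoiding a second read from disk when these are shared with other
    /// components.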
    pub fn from_existing_vocab_and_merges(vocab: Arc<OpenAiGptVocab>, merges: Rc<BpePairVocab>) -> OpenAiGptTokenizer {
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        let cache = RefCell::new(HashMap::new());
        OpenAiGptTokenizer { vocab, base_tokenizer, bpe_ranks: merges, cache }
    }
}

impl Tokenizer<OpenAiGptVocab> for OpenAiGptTokenizer {
    fn vocab(&self) -> &OpenAiGptVocab {
        &self.vocab
    }

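    // Tokenization proceeds in three steps: split the input on registered special tokens,
    // run the base tokenizer (whitespace and punctuation splitting) on the non-special
    // segments, then apply BPE to each resulting word, memoizing the output per word.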
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokenized_text: Vec<String> = Vec::with_capacity(text.len());

        let temp_text = split_on_special_tokens(text, self.vocab.as_ref());

        for text in temp_text {
            if !self.vocab.special_values.contains_key(text) {
                let sub_words: Vec<String> = self.base_tokenizer.tokenize(text);

                for word in sub_words {
                    // Re-use the memoized BPE output if this word has been seen before.
                    let cached: bool = match self.cache.borrow().get(&word) {
                        Some(value) => {
                            tokenized_text.extend(value.clone());
                            true
                        }
                        None => false
                    };
                    if !cached {
                        let bpe_output = openai_gpt_bpe(&word, &self.bpe_ranks);
                        self.cache.borrow_mut().insert(word.to_owned(), bpe_output.clone());
                        tokenized_text.extend(bpe_output);
                    }
                }
            } else {
                // Special tokens are passed through unchanged.
                tokenized_text.push(text.to_owned());
            }
        }
        tokenized_text
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenAiGptVocab;
    use std::collections::HashMap;
    use crate::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, TokenizedInput};
    use crate::preprocessing::vocab::base_vocab::swap_key_values;

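    // Builds a small, hand-crafted vocabulary so the tests do not depend on external files.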
    fn generate_test_vocab() -> OpenAiGptVocab {
        let values: HashMap<String, i64> = [
            ("t".to_owned(), 0),
            ("h".to_owned(), 1),
            ("a</w>".to_owned(), 2),
            ("n".to_owned(), 3),
            ("the".to_owned(), 4),
            ("Ġ".to_owned(), 5),
            ("<unk>".to_owned(), 6),
            ("o</w>".to_owned(), 7)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("<unk>".to_owned(), 6),
        ].iter().cloned().collect();

        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);

        OpenAiGptVocab { values, indices, unknown_value: "<unk>", special_values, special_indices }
    }

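    // Builds a minimal set of ranked BPE merge pairs used by the tokenization tests.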
    fn generate_test_merges() -> BpePairVocab {
        let values: HashMap<(String, String), i64> = [
            (("4".to_owned(), "t".to_owned()), 0),
            (("2".to_owned(), "n".to_owned()), 1),
            (("r".to_owned(), "th</w>".to_owned()), 2),
            (("t".to_owned(), "he</w>".to_owned()), 3),
            (("h".to_owned(), "e".to_owned()), 4),
            (("t".to_owned(), "h</w>".to_owned()), 5),
            (("t".to_owned(), "h".to_owned()), 6),
        ].iter().cloned().collect();

        BpePairVocab { values }
    }

    #[test]
    fn test_openai_gpt_tokenizer() {
        let vocab = Arc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let openai_gpt_tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let test_tuples = [
            (
                "the earth",
                vec!("th", "e</w>", "e", "a", "rth</w>")
            ),
            (
                "",
                vec!()
            ),
            (
                " ",
                vec!("<unk>")
            ),
            (
                " \n ",
                vec!("<unk>")
            ),
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(openai_gpt_tokenizer.tokenize(*source_text), *expected_result);
        }

        assert_eq!(openai_gpt_tokenizer.tokenize_list(source_texts.clone()), expected_results);
    }

    #[test]
    fn test_encode() {
        let vocab = Arc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let openai_gpt_tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                "the earth",
                TokenizedInput { token_ids: vec!(6, 6, 6, 6, 6), segment_ids: vec!(0, 0, 0, 0, 0), special_tokens_mask: vec!(0, 0, 0, 0, 0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                " ",
                TokenizedInput { token_ids: vec!(6), segment_ids: vec!(0), special_tokens_mask: vec!(0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "",
                TokenizedInput { token_ids: vec!(), segment_ids: vec!(), special_tokens_mask: vec!(), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(openai_gpt_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(openai_gpt_tokenizer.encode_list(source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
    }
}