use crate::OpenAiGptVocab;
use crate::preprocessing::vocab::base_vocab::Vocab;
use crate::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use std::collections::HashMap;
use crate::preprocessing::tokenizer::tokenization_utils::{ctrl_bpe, split_on_special_tokens};
use std::rc::Rc;
use std::cell::RefCell;
use crate::preprocessing::vocab::bpe_vocab::BpePairVocab;
use regex::Regex;

/// Tokenizer for the CTRL model, based on a vocabulary and a ranked list of BPE merges.
/// BPE outputs for already-seen words are memoized in `cache` to avoid re-splitting them.
pub struct CtrlTokenizer {
    vocab: Rc<OpenAiGptVocab>,
    bpe_ranks: Rc<BpePairVocab>,
    cache: RefCell<HashMap<String, Vec<String>>>,
    regex_pattern: Regex,
}

impl CtrlTokenizer {
    /// Creates a tokenizer from a vocabulary file and a BPE merges file on disk.
    pub fn from_file(vocab_path: &str, merges_path: &str) -> CtrlTokenizer {
        let vocab = Rc::new(OpenAiGptVocab::from_file(vocab_path));
        let bpe_ranks = Rc::new(BpePairVocab::from_file(merges_path));
        let cache = RefCell::new(HashMap::new());
        // Matches runs of non-whitespace characters, optionally followed by a newline
        let regex_pattern = Regex::new(r"\S+\n?").unwrap();
        CtrlTokenizer { vocab, bpe_ranks, cache, regex_pattern }
    }

    /// Creates a tokenizer from pre-loaded vocabulary and merges, allowing these
    /// resources to be shared between tokenizers.
    pub fn from_existing_vocab_and_merges(vocab: Rc<OpenAiGptVocab>, merges: Rc<BpePairVocab>) -> CtrlTokenizer {
        let cache = RefCell::new(HashMap::new());
        let regex_pattern = Regex::new(r"\S+\n?").unwrap();
        CtrlTokenizer { vocab, bpe_ranks: merges, cache, regex_pattern }
    }
}
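
// A minimal usage sketch (file paths are hypothetical; the vocabulary must be
// readable by `OpenAiGptVocab::from_file` and the merges by `BpePairVocab::from_file`):
//
//     let tokenizer = CtrlTokenizer::from_file("path/to/vocab.json", "path/to/merges.txt");
//     let tokens = tokenizer.tokenize("the earth");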

impl Tokenizer<OpenAiGptVocab> for CtrlTokenizer {
    fn vocab(&self) -> &OpenAiGptVocab {
        &self.vocab
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        // Upper bound on the number of tokens: one per input byte
        let mut tokenized_text: Vec<String> = Vec::with_capacity(text.len());
        // Isolate registered special tokens so that they bypass BPE
        let temp_text = split_on_special_tokens(text, self.vocab.as_ref());
        for text in temp_text {
            if !self.vocab.special_values.contains_key(text) {
                // Split on whitespace and run BPE on each word, re-using the cached
                // output for words that were already processed
                for word in self.regex_pattern.find_iter(text.as_ref()) {
                    let cached: bool = match self.cache.borrow().get(word.as_str()) {
                        Some(value) => {
                            tokenized_text.extend(value.clone());
                            true
                        }
                        None => false
                    };
                    if !cached {
                        let bpe_output = ctrl_bpe(word.as_str(), &self.bpe_ranks);
                        self.cache.borrow_mut().insert(word.as_str().to_owned(), bpe_output.clone());
                        tokenized_text.extend(bpe_output);
                    }
                }
            } else {
                // Special tokens are emitted as-is
                tokenized_text.push(text.to_owned());
            }
        }
        tokenized_text
    }
}
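
// Illustrative trace based on the toy fixtures in the tests below (not a real CTRL
// vocabulary): `tokenize("the earth")` finds the words ["the", "earth"]; BPE merges
// "the" into a single known token, while "earth" remains character-split, yielding
// ["the", "e@@", "a@@", "r@@", "t@@", "h"].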

#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenAiGptVocab;
    use std::collections::HashMap;
    use crate::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, TokenizedInput};
    use crate::preprocessing::vocab::base_vocab::swap_key_values;

    fn generate_test_vocab() -> OpenAiGptVocab {
        let values: HashMap<String, i64> = [
            ("t".to_owned(), 0),
            ("h".to_owned(), 1),
            ("a@@".to_owned(), 2),
            ("n".to_owned(), 3),
            ("the".to_owned(), 4),
            ("r@@".to_owned(), 5),
            ("<unk>".to_owned(), 6),
            ("o@@".to_owned(), 8)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("<unk>".to_owned(), 6),
        ].iter().cloned().collect();

        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);

        OpenAiGptVocab { values, indices, unknown_value: "<unk>", special_values, special_indices }
    }

    fn generate_test_merges() -> BpePairVocab {
        let values: HashMap<(String, String), i64> = [
            (("t".to_owned(), "h".to_owned()), 0),
            (("a".to_owned(), "n".to_owned()), 1),
            (("i".to_owned(), "n".to_owned()), 2),
            (("th".to_owned(), "e</w>".to_owned()), 3),
            (("e".to_owned(), "r".to_owned()), 4),
            (("r".to_owned(), "e".to_owned()), 5),
            (("l".to_owned(), "l".to_owned()), 6),
        ].iter().cloned().collect();

        BpePairVocab { values }
    }
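
    // With these ranks, and assuming `ctrl_bpe` follows the usual CTRL BPE scheme
    // (the last character carries a "</w>" end-of-word marker, non-final sub-tokens
    // get an "@@" suffix), "the" merges as: ("t", "h") has rank 0, so "t h e</w>"
    // becomes "th e</w>"; ("th", "e</w>") has rank 3, giving "the</w>", i.e. the
    // vocabulary token "the". No pair of "earth" is ranked, so it stays split as
    // "e@@ a@@ r@@ t@@ h".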

    #[test]
    fn test_ctrl_tokenizer() {
        let vocab = Rc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let ctrl_tokenizer: CtrlTokenizer = CtrlTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let test_tuples = [
            (
                "the earth",
                vec!("the", "e@@", "a@@", "r@@", "t@@", "h")
            ),
            (
                "Hello, world!",
                vec!("H@@", "e@@", "ll@@", "o@@", ",", "w@@", "o@@", "r@@", "l@@", "d@@", "!")
            ),
            (
                "",
                vec!()
            ),
            (
                " ",
                vec!("<unk>")
            ),
            (
                " \n ",
                vec!("<unk>")
            ),
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(ctrl_tokenizer.tokenize(*source_text), *expected_result);
        }

        assert_eq!(ctrl_tokenizer.tokenize_list(source_texts.clone()), expected_results);
    }
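
    // Additional sketch (not part of the original suite): tokenizing the same text
    // twice should first populate and then re-use the per-word BPE cache.
    #[test]
    fn test_tokenization_cache_reuse() {
        let vocab = Rc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let ctrl_tokenizer: CtrlTokenizer = CtrlTokenizer::from_existing_vocab_and_merges(vocab, merges);

        // First pass stores the BPE output for each word
        let first = ctrl_tokenizer.tokenize("the earth");
        assert!(ctrl_tokenizer.cache.borrow().contains_key("the"));
        assert!(ctrl_tokenizer.cache.borrow().contains_key("earth"));

        // Second pass is served from the cache and must be identical
        assert_eq!(ctrl_tokenizer.tokenize("the earth"), first);
    }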

    #[test]
    fn test_encode() {
        let vocab = Rc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let ctrl_tokenizer: CtrlTokenizer = CtrlTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                "the earth",
                TokenizedInput { token_ids: vec!(4, 6, 2, 5, 6, 1), segment_ids: vec!(0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(0, 0, 0, 0, 0, 0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "Hello, world!",
                TokenizedInput { token_ids: vec!(6, 6, 6, 8, 6, 6, 8, 5, 6, 6, 6), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "",
                TokenizedInput { token_ids: vec!(), segment_ids: vec!(), special_tokens_mask: vec!(), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(ctrl_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(ctrl_tokenizer.encode_list(source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
    }
}