
rust_transformers/preprocessing/tokenizer/ctrl_tokenizer.rs

// Copyright 2018 Salesforce
// Copyright 2018 The HuggingFace Inc. team.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::OpenAiGptVocab;
use crate::preprocessing::vocab::base_vocab::Vocab;
use crate::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use std::collections::HashMap;
use crate::preprocessing::tokenizer::tokenization_utils::{ctrl_bpe, split_on_special_tokens};
use std::rc::Rc;
use std::cell::RefCell;
use crate::preprocessing::vocab::bpe_vocab::BpePairVocab;
use regex::Regex;

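/// CTRL tokenizer: splits text on special tokens and whitespace, then applies
/// byte-pair encoding (BPE) to each word, caching per-word results.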
pub struct CtrlTokenizer {
    vocab: Rc<OpenAiGptVocab>,
    bpe_ranks: Rc<BpePairVocab>,
    /// Per-word cache of previously computed BPE outputs.
    cache: RefCell<HashMap<String, Vec<String>>>,
    /// Matches whitespace-delimited words, keeping a trailing newline if present.
    regex_pattern: Regex,
}

impl CtrlTokenizer {
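    /// Builds a `CtrlTokenizer` from a vocabulary file and a BPE merges file.
    ///
    /// A minimal usage sketch (the paths are placeholders, and the `use` paths
    /// are assumed from this file's location in the crate):
    ///
    /// ```no_run
    /// use rust_transformers::preprocessing::tokenizer::ctrl_tokenizer::CtrlTokenizer;
    /// use rust_transformers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
    ///
    /// let tokenizer = CtrlTokenizer::from_file("path/to/vocab.json", "path/to/merges.txt");
    /// let tokens = tokenizer.tokenize("the earth");
    /// ```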
    pub fn from_file(vocab_path: &str, merges_path: &str) -> CtrlTokenizer {
        let vocab = Rc::new(OpenAiGptVocab::from_file(vocab_path));
        let bpe_ranks = Rc::new(BpePairVocab::from_file(merges_path));
        let cache = RefCell::new(HashMap::new());
        let regex_pattern = Regex::new(r"\S+\n?").unwrap();
        CtrlTokenizer { vocab, bpe_ranks, cache, regex_pattern }
    }

    pub fn from_existing_vocab_and_merges(vocab: Rc<OpenAiGptVocab>, merges: Rc<BpePairVocab>) -> CtrlTokenizer {
        let cache = RefCell::new(HashMap::new());
        let regex_pattern = Regex::new(r"\S+\n?").unwrap();
        CtrlTokenizer { vocab, bpe_ranks: merges, cache, regex_pattern }
    }
}

impl Tokenizer<OpenAiGptVocab> for CtrlTokenizer {
    fn vocab(&self) -> &OpenAiGptVocab {
        &self.vocab
    }

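    /// Tokenizes `text`: splits it on registered special tokens, then on the
    /// whitespace regex, and runs BPE on each resulting word, reusing cached
    /// results for words seen before.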
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokenized_text: Vec<String> = Vec::with_capacity(text.len());
        let temp_text = split_on_special_tokens(text, self.vocab.as_ref());
        for text in temp_text {
            if !self.vocab.special_values.contains_key(text) {
                for word in self.regex_pattern.find_iter(text.as_ref()) {
                    // Reuse the cached BPE output if this word was seen before.
                    let cached: bool = match self.cache.borrow().get(word.as_str()) {
                        Some(value) => {
                            tokenized_text.extend(value.clone());
                            true
                        }
                        None => false
                    };
                    if !cached {
                        let bpe_output = ctrl_bpe(word.as_str(), &self.bpe_ranks);
                        self.cache.borrow_mut().insert(word.as_str().to_owned(), bpe_output.clone());
                        tokenized_text.extend(bpe_output);
                    }
                }
            } else {
                // Special tokens are passed through without BPE.
                tokenized_text.push(text.to_owned());
            }
        }
        tokenized_text
    }
}
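
// Illustrative walk-through (a sketch based on the toy merges in the tests
// below): "the earth" is split by the regex into ["the", "earth"]. The merges
// ("t", "h") and ("th", "e</w>") collapse "the" into the single token "the",
// while "earth" matches no merges and falls back to its characters, each but
// the last carrying the "@@" continuation suffix:
// ["e@@", "a@@", "r@@", "t@@", "h"].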

#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenAiGptVocab;
    use std::collections::HashMap;
    use crate::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, TokenizedInput};
    use crate::preprocessing::vocab::base_vocab::swap_key_values;

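    // In the fixtures below, "@@" marks a sub-token continued by the next
    // token, and "</w>" is the end-of-word marker used during BPE merges.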
    fn generate_test_vocab() -> OpenAiGptVocab {
        let values: HashMap<String, i64> = [
            ("t".to_owned(), 0),
            ("h".to_owned(), 1),
            ("a@@".to_owned(), 2),
            ("n".to_owned(), 3),
            ("the".to_owned(), 4),
            ("r@@".to_owned(), 5),
            ("<unk>".to_owned(), 6),
            ("o@@".to_owned(), 8)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("<unk>".to_owned(), 6),
        ].iter().cloned().collect();

        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);

        OpenAiGptVocab { values, indices, unknown_value: "<unk>", special_values, special_indices }
    }

    fn generate_test_merges() -> BpePairVocab {
        let values: HashMap<(String, String), i64> = [
            (("t".to_owned(), "h".to_owned()), 0),
            (("a".to_owned(), "n".to_owned()), 1),
            (("i".to_owned(), "n".to_owned()), 2),
            (("th".to_owned(), "e</w>".to_owned()), 3),
            (("e".to_owned(), "r".to_owned()), 4),
            (("r".to_owned(), "e".to_owned()), 5),
            (("l".to_owned(), "l".to_owned()), 6),
        ].iter().cloned().collect();

        BpePairVocab { values }
    }

    #[test]
    fn test_ctrl_tokenizer() {
        // Given
        let vocab = Rc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let ctrl_tokenizer: CtrlTokenizer = CtrlTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let test_tuples = [
            (
                "the earth",
                vec!("the", "e@@", "a@@", "r@@", "t@@", "h")
            ),
            (
                "Hello, world!",
                vec!("H@@", "e@@", "ll@@", "o@@", ",", "w@@", "o@@", "r@@", "l@@", "d@@", "!")
            ),
            (
                "",
                vec!()
            ),
            (
                " ",
                vec!("<unk>")
            ),
            (
                " \n ",
                vec!("<unk>")
            ),
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();

        // When & Then
        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(ctrl_tokenizer.tokenize(*source_text), *expected_result);
        }

        assert_eq!(ctrl_tokenizer.tokenize_list(source_texts.clone()), expected_results);
    }

    #[test]
    fn test_encode() {
        // Given
        let vocab = Rc::new(generate_test_vocab());
        let merges = Rc::new(generate_test_merges());
        let ctrl_tokenizer: CtrlTokenizer = CtrlTokenizer::from_existing_vocab_and_merges(vocab, merges);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                "the earth",
                TokenizedInput { token_ids: vec!(4, 6, 2, 5, 6, 1), segment_ids: vec!(0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(0, 0, 0, 0, 0, 0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "Hello, world!",
                TokenizedInput { token_ids: vec!(6, 6, 6, 8, 6, 6, 8, 5, 6, 6, 6), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "",
                TokenizedInput { token_ids: vec!(), segment_ids: vec!(), special_tokens_mask: vec!(), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

        // When & Then
        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(ctrl_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(ctrl_tokenizer.encode_list(source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
    }
}
195}