// rust_tokenizers/tokenizer/xlnet_tokenizer.rs

use std::path::Path;

use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{TokenIdsWithOffsets, TokenIdsWithSpecialTokens};
use crate::tokenizer::tokenization_utils::strip_accents;
use crate::tokenizer::tokenization_utils::{
    clean_text, decompose_nfkc, is_whitespace, lowercase, replace_string, split_on_special_tokens,
};
use crate::tokenizer::{MultiThreadedTokenizer, Tokenizer};
use crate::vocab::{SentencePieceModel, Vocab, XLNetVocab};
use crate::{Mask, Offset, OffsetSize, Token, TokenRef};

#[allow(clippy::upper_case_acronyms)]
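/// # XLNet tokenizer
/// XLNet tokenizer performing:
/// - splitting on special tokens
/// - text cleaning and NFKC decomposition
/// - (optional) lower casing
/// - (optional) accent stripping
/// - SentencePiece decomposition of the input text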
pub struct XLNetTokenizer {
    model: SentencePieceModel,
    vocab: XLNetVocab,
    lower_case: bool,
    strip_accents: bool,
}

impl XLNetTokenizer {
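    /// Create a new instance of a `XLNetTokenizer`.
    /// Expects a SentencePiece model file as an input.
    ///
    /// # Parameters
    /// - path (`P`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the tokenizer is re-exported at
    /// `rust_tokenizers::tokenizer` and using a placeholder model path:
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::XLNetTokenizer;
    ///
    /// let lower_case = false;
    /// let strip_accents = false;
    /// let tokenizer =
    ///     XLNetTokenizer::from_file("path/to/spiece.model", lower_case, strip_accents).unwrap();
    /// ```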
    pub fn from_file<P: AsRef<Path>>(
        path: P,
        lower_case: bool,
        strip_accents: bool,
    ) -> Result<XLNetTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLNetVocab::from_file(path)?;
        Ok(XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        })
    }

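    /// Create a new instance of a `XLNetTokenizer`.
    /// Expects a SentencePiece model file and a special token mapping file as inputs.
    ///
    /// # Parameters
    /// - path (`P`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text
    /// - special_token_mapping_path (`S`): path to a special token mapping file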
    pub fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        lower_case: bool,
        strip_accents: bool,
        special_token_mapping_path: S,
    ) -> Result<XLNetTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab =
            XLNetVocab::from_file_with_special_token_mapping(path, special_token_mapping_path)?;
        Ok(XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        })
    }

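    /// Create a new instance of a `XLNetTokenizer` from an existing vocabulary and SentencePiece model.
    ///
    /// # Parameters
    /// - vocab (`XLNetVocab`): vocabulary to use
    /// - model (`SentencePieceModel`): SentencePiece model to use
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text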
    pub fn from_existing_vocab_and_model(
        vocab: XLNetVocab,
        model: SentencePieceModel,
        lower_case: bool,
        strip_accents: bool,
    ) -> XLNetTokenizer {
        XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        }
    }

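    /// Post-processing of SentencePiece pieces: a piece ending in a digit followed by a comma
    /// (e.g. `"9,"`) is split so that the trailing comma becomes its own token and the
    /// remainder is re-decoded with the SentencePiece model, matching the original XLNet
    /// tokenizer's handling of such pieces.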
    fn post_process_pieces<'a>(&self, tokens: &'a mut Vec<Token>) -> &'a Vec<Token> {
        let mut positions_to_update: Vec<(usize, Vec<Token>)> = vec![];
        for (token_idx, token) in tokens.iter().enumerate() {
            if token.text.chars().count() > 1 {
                let mut token_chars = token.text.chars().rev();
                if (token_chars.next().unwrap() == ',')
                    & token_chars.next().unwrap().is_ascii_digit()
                {
                    let mut new_token = token.clone();
                    let last_char = new_token.text.pop().unwrap();
                    let updated_tokens = self.model.decode_forward_token_ref(new_token.as_ref());
                    let updated_tokens = self.model.decode_backward(&updated_tokens);
                    let mut updated_tokens = self.model.parse_nodes_to_tokens(updated_tokens);
                    if !token.text.starts_with('\u{2581}')
                        & updated_tokens[0].text.starts_with('\u{2581}')
                    {
                        if updated_tokens[0].text.chars().count() == 1 {
                            updated_tokens.remove(0);
                        } else {
                            let first_char_length =
                                updated_tokens[0].text.chars().next().unwrap().len_utf8();
                            updated_tokens[0].text = (updated_tokens[0].text[first_char_length..])
                                .parse()
                                .unwrap();
                        }
                    }
                    updated_tokens.push(Token {
                        text: last_char.to_string(),
                        offset: Offset {
                            begin: token.offset.end,
                            end: token.offset.end,
                        },
                        reference_offsets: vec![*token.reference_offsets.last().unwrap()],
                        mask: token.mask,
                    });
                    positions_to_update.push((token_idx, updated_tokens.clone()));
                }
            }
        }
        for (pos, new_tokens) in positions_to_update {
            tokens.splice(pos..pos + 1, new_tokens);
        }
        tokens
    }
}

impl Tokenizer<XLNetVocab> for XLNetTokenizer {
    fn vocab(&self) -> &XLNetVocab {
        &self.vocab
    }
    fn vocab_mut(&mut self) -> &mut XLNetVocab {
        &mut self.vocab
    }

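    // Tokenization pipeline: split on special tokens, normalize each non-special token
    // (quote replacement, text cleaning, NFKC decomposition, optional lower casing and
    // accent stripping, whitespace replacement with '\u{2581}'), then decompose it with
    // the SentencePiece model and post-process the resulting pieces.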
    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        let mut tokens = split_on_special_tokens(text, &self.vocab)
            .into_iter()
            .map(|token| token.to_owned())
            .collect::<Vec<Token>>();

        let mut sub_tokens: Vec<Token> = Vec::new();
        for token in tokens.iter_mut() {
            if token.mask != Mask::Special && token.mask != Mask::Unknown {
                replace_string(token, "``", "\"");
                replace_string(token, "\'\'", "\"");
                clean_text(token, true);
                decompose_nfkc(token);
                if self.lower_case {
                    lowercase(token);
                }
                if self.strip_accents {
                    strip_accents(token);
                }
                token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");
                if !token.text.starts_with('\u{2581}') {
                    token.text.insert(0, '\u{2581}');
                    token.reference_offsets.insert(0, 0);
                };
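                // Decompose the normalized token with the SentencePiece model: run the
                // forward decoding, backtrack to recover the selected pieces, and convert
                // the resulting nodes into `Token`s.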
                let output = self.model.decode_forward_token_ref(token.as_ref());
                let decoded = self.model.decode_backward(&output);

                let mut output: Vec<Token> = self.model.parse_nodes_to_tokens(decoded);
                self.post_process_pieces(&mut output);
                sub_tokens.extend(output)
            } else {
                sub_tokens.push(token.clone());
            }
        }
        sub_tokens
    }

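    // Convert pieces back to a string by replacing the '\u{2581}' whitespace marker with
    // regular spaces and concatenating the pieces.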
    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        tokens
            .into_iter()
            .map(|v| v.replace('\u{2581}', " "))
            .collect::<Vec<String>>()
            .join("")
    }

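    // Build inputs in the XLNet layout: sequence A, separator token, optionally sequence B
    // and another separator token, with a classification token appended at the end.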
    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        let mut output: Vec<i64> = vec![];
        let mut token_segment_ids: Vec<i8> = vec![];
        let mut special_tokens_mask: Vec<i8> = vec![];
        let mut offsets: Vec<Option<Offset>> = vec![];
        let mut original_offsets: Vec<Vec<OffsetSize>> = vec![];
        let mut mask: Vec<Mask> = vec![];
        special_tokens_mask.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_ids_with_offsets_1.ids.len() + 1]);
        output.extend(tokens_ids_with_offsets_1.ids);
        output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
        offsets.extend(tokens_ids_with_offsets_1.offsets);
        offsets.push(None);
        original_offsets.extend(tokens_ids_with_offsets_1.reference_offsets);
        original_offsets.push(vec![]);
        mask.extend(tokens_ids_with_offsets_1.masks);
        mask.push(Mask::Special);
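        // If a second sequence is provided, append its ids followed by another separator
        // token, using segment id 1 for the second segment.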
        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.extend(vec![0; length]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; length + 1]);
            output.extend(tokens_ids_with_offsets_2_value.ids);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            offsets.push(None);
            original_offsets.push(vec![]);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
            mask.push(Mask::Special);
        }
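        // The classification token is appended at the end of the full sequence.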
        output.push(self.vocab.token_to_id(self.vocab.get_cls_value()));
        special_tokens_mask.push(1);
        offsets.push(None);
        original_offsets.push(vec![]);
        mask.push(Mask::Special);
        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}

impl MultiThreadedTokenizer<XLNetVocab> for XLNetTokenizer {}