// rust_tokenizers/tokenizer/xlm_roberta_tokenizer.rs

use std::path::Path;

use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{
    Mask, Offset, OffsetSize, Token, TokenIdsWithOffsets, TokenIdsWithSpecialTokens, TokenRef,
};
use crate::tokenizer::tokenization_utils::{
    clean_text, decompose_nfkc, is_whitespace, lowercase, split_on_special_tokens,
};
use crate::tokenizer::{MultiThreadedTokenizer, Tokenizer};
use crate::vocab::{SentencePieceModel, Vocab, XLMRobertaVocab};

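/// XLM-RoBERTa tokenizer wrapping a SentencePiece model and an `XLMRobertaVocab`.
/// Tokenization performs text cleaning, NFKC decomposition, optional lower casing and
/// SentencePiece sub-word decomposition.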
#[allow(clippy::upper_case_acronyms)]
pub struct XLMRobertaTokenizer {
    model: SentencePieceModel,
    vocab: XLMRobertaVocab,
    lower_case: bool,
}

impl XLMRobertaTokenizer {
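    /// Creates an `XLMRobertaTokenizer` from a SentencePiece model file, which also
    /// provides the vocabulary. `lower_case` controls whether the input text is lower
    /// cased before tokenization.
    ///
    /// The example below is an illustrative sketch only: the model path is a placeholder.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::XLMRobertaTokenizer;
    ///
    /// let lower_case = false;
    /// let tokenizer =
    ///     XLMRobertaTokenizer::from_file("path/to/sentencepiece.bpe.model", lower_case)
    ///         .expect("failed to load the SentencePiece model");
    /// ```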
    pub fn from_file<P: AsRef<Path>>(
        path: P,
        lower_case: bool,
    ) -> Result<XLMRobertaTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLMRobertaVocab::from_file(path)?;
        Ok(XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        })
    }

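    /// Creates an `XLMRobertaTokenizer` from a SentencePiece model file and a special
    /// token mapping file overriding the default special tokens. `lower_case` controls
    /// whether the input text is lower cased before tokenization.
    ///
    /// The example below is an illustrative sketch only: both file paths are placeholders.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::XLMRobertaTokenizer;
    ///
    /// let lower_case = false;
    /// let tokenizer = XLMRobertaTokenizer::from_file_with_special_token_mapping(
    ///     "path/to/sentencepiece.bpe.model",
    ///     lower_case,
    ///     "path/to/special_token_mapping.json",
    /// )
    /// .expect("failed to load the tokenizer");
    /// ```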
    pub fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        lower_case: bool,
        special_token_mapping_path: S,
    ) -> Result<XLMRobertaTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLMRobertaVocab::from_file_with_special_token_mapping(
            path,
            special_token_mapping_path,
        )?;
        Ok(XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        })
    }

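    /// Creates an `XLMRobertaTokenizer` from an already loaded vocabulary and
    /// SentencePiece model, which avoids re-reading the files when several tokenizers
    /// share them.
    ///
    /// Illustrative sketch only: the model path is a placeholder and loading the vocab
    /// through the `Vocab` trait is assumed.
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::XLMRobertaTokenizer;
    /// use rust_tokenizers::vocab::{SentencePieceModel, Vocab, XLMRobertaVocab};
    ///
    /// let model = SentencePieceModel::from_file("path/to/sentencepiece.bpe.model").unwrap();
    /// let vocab = XLMRobertaVocab::from_file("path/to/sentencepiece.bpe.model").unwrap();
    /// let tokenizer = XLMRobertaTokenizer::from_existing_vocab_and_model(vocab, model, false);
    /// ```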
    pub fn from_existing_vocab_and_model(
        vocab: XLMRobertaVocab,
        model: SentencePieceModel,
        lower_case: bool,
    ) -> XLMRobertaTokenizer {
        XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        }
    }
}

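/// The `Tokenizer` implementation provides the generic tokenization and encoding API on
/// top of the SentencePiece model: input text is split on special tokens, cleaned,
/// NFKC-decomposed, optionally lower cased, whitespace is replaced by the SentencePiece
/// marker (U+2581) and the pieces are decoded by the model.
///
/// Illustrative sketch of the trait methods, assuming a local model file and the
/// `encode` signature of the `Tokenizer` trait:
///
/// ```no_run
/// use rust_tokenizers::tokenizer::{Tokenizer, TruncationStrategy, XLMRobertaTokenizer};
///
/// let tokenizer = XLMRobertaTokenizer::from_file("path/to/sentencepiece.bpe.model", false).unwrap();
/// let tokens = tokenizer.tokenize("Hello, world!");
/// let encoding = tokenizer.encode(
///     "Hello, world!",
///     None,
///     128,
///     &TruncationStrategy::LongestFirst,
///     0,
/// );
/// ```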
impl Tokenizer<XLMRobertaVocab> for XLMRobertaTokenizer {
    fn vocab(&self) -> &XLMRobertaVocab {
        &self.vocab
    }
    fn vocab_mut(&mut self) -> &mut XLMRobertaVocab {
        &mut self.vocab
    }

    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        // Isolate registered special tokens so they are never split by the model
        let mut tokens = split_on_special_tokens(text, &self.vocab)
            .into_iter()
            .map(|token| token.to_owned())
            .collect::<Vec<Token>>();

        let mut sub_tokens: Vec<Token> = Vec::new();
        for token in tokens.iter_mut() {
            if token.mask != Mask::Special && token.mask != Mask::Unknown {
                // Clean, NFKC-decompose and (optionally) lower case the token text
                clean_text(token, true);
                decompose_nfkc(token);
                if self.lower_case {
                    lowercase(token);
                }
                // Replace whitespace with the SentencePiece marker (U+2581) and make
                // sure the token starts with it
                token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");
                if !token.text.starts_with('\u{2581}') {
                    token.text.insert(0, '\u{2581}');
                    token.reference_offsets.insert(0, 0);
                };
                // Forward/backward pass through the SentencePiece lattice to recover
                // the sub-token decomposition
                let output = self.model.decode_forward_token_ref(token.as_ref());
                let decoded = self.model.decode_backward(&output);

                let output: Vec<Token> = self.model.parse_nodes_to_tokens(decoded);
                sub_tokens.extend(output)
            } else {
                sub_tokens.push(token.clone());
            }
        }
        sub_tokens
    }

    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        // Map the SentencePiece whitespace marker back to regular spaces
        tokens
            .into_iter()
            .map(|v| v.replace('\u{2581}', " "))
            .collect::<Vec<String>>()
            .join("")
    }

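    /// Adds the XLM-RoBERTa special tokens around one or two tokenized sequences,
    /// producing `<s> A </s>` for a single input and `<s> A </s> </s> B </s>` for a
    /// pair, together with the matching segment ids, special token mask, offsets and
    /// masks.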
    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        let mut output: Vec<i64> = vec![];
        let mut token_segment_ids: Vec<i8> = vec![];
        let mut special_tokens_mask: Vec<i8> = vec![];
        let mut offsets: Vec<Option<Offset>> = vec![];
        let mut original_offsets: Vec<Vec<OffsetSize>> = vec![];
        let mut mask: Vec<Mask> = vec![];
        // First sequence: <s> tokens_1 </s>
        special_tokens_mask.push(1);
        special_tokens_mask.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_ids_with_offsets_1.ids.len() + 2]);
        output.push(self.vocab.token_to_id(self.vocab.get_cls_value()));
        output.extend(tokens_ids_with_offsets_1.ids);
        output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
        offsets.push(None);
        offsets.extend(tokens_ids_with_offsets_1.offsets);
        offsets.push(None);
        original_offsets.push(vec![]);
        original_offsets.extend(tokens_ids_with_offsets_1.reference_offsets);
        original_offsets.push(vec![]);
        mask.push(Mask::Special);
        mask.extend(tokens_ids_with_offsets_1.masks);
        mask.push(Mask::Special);
        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            // Optional second sequence: </s> tokens_2 </s>
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.push(1);
            special_tokens_mask.extend(vec![0; length]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; length + 2]);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            output.extend(tokens_ids_with_offsets_2_value.ids);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            offsets.push(None);
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            original_offsets.push(vec![]);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            offsets.push(None);
            original_offsets.push(vec![]);
            mask.push(Mask::Special);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
            mask.push(Mask::Special);
        }
        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}

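/// Marker implementation enabling the parallel batch methods of `MultiThreadedTokenizer`
/// for this tokenizer.
///
/// Illustrative sketch, assuming a local model file and the batch-encoding signature of
/// `MultiThreadedTokenizer::encode_list`:
///
/// ```no_run
/// use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, TruncationStrategy, XLMRobertaTokenizer};
///
/// let tokenizer = XLMRobertaTokenizer::from_file("path/to/sentencepiece.bpe.model", false).unwrap();
/// // Fully qualified call to disambiguate from the single-threaded `Tokenizer::encode_list`
/// let encodings = MultiThreadedTokenizer::encode_list(
///     &tokenizer,
///     &["First sentence", "A second, longer sentence"],
///     128,
///     &TruncationStrategy::LongestFirst,
///     0,
/// );
/// ```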
impl MultiThreadedTokenizer<XLMRobertaVocab> for XLMRobertaTokenizer {}