rust_tokenizers/tokenizer/xlnet_tokenizer.rs

// Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::Path;

use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{TokenIdsWithOffsets, TokenIdsWithSpecialTokens};
use crate::tokenizer::tokenization_utils::strip_accents;
use crate::tokenizer::tokenization_utils::{
    clean_text, decompose_nfkc, is_whitespace, lowercase, replace_string, split_on_special_tokens,
};
use crate::tokenizer::{MultiThreadedTokenizer, Tokenizer};
use crate::vocab::{SentencePieceModel, Vocab, XLNetVocab};
use crate::{Mask, Offset, OffsetSize, Token, TokenRef};

/// # XLNet tokenizer
/// XLNet tokenizer performing:
/// - Splitting on special tokens
/// - Text cleaning
/// - NFKC decomposition
/// - (optional) lower casing
/// - (optional) accents stripping
/// - SentencePiece decomposition
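///
/// # Example
///
/// A minimal usage sketch: the vocabulary path below is a placeholder, and a valid
/// SentencePiece model file must exist at that location for the call to succeed.
///
/// ```no_run
/// use rust_tokenizers::tokenizer::{Tokenizer, XLNetTokenizer};
/// let lower_case = false;
/// let strip_accents = false;
/// let tokenizer =
///     XLNetTokenizer::from_file("path/to/vocab/file", lower_case, strip_accents).unwrap();
/// let tokens = tokenizer.tokenize("Hello, world!");
/// ```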
#[allow(clippy::upper_case_acronyms)]
pub struct XLNetTokenizer {
    model: SentencePieceModel,
    vocab: XLNetVocab,
    lower_case: bool,
    strip_accents: bool,
}

impl XLNetTokenizer {
    /// Create a new instance of an `XLNetTokenizer`
    /// Expects a SentencePiece protobuf file as an input.
    ///
    /// # Parameters
    /// - path (`&str`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLNetTokenizer};
    /// let lower_case = false;
    /// let strip_accents = false;
    /// let tokenizer =
    ///     XLNetTokenizer::from_file("path/to/vocab/file", lower_case, strip_accents).unwrap();
    /// ```
    pub fn from_file<P: AsRef<Path>>(
        path: P,
        lower_case: bool,
        strip_accents: bool,
    ) -> Result<XLNetTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLNetVocab::from_file(path)?;
        Ok(XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        })
    }

    /// Create a new instance of an `XLNetTokenizer`
    /// Expects a SentencePiece protobuf file and special token mapping file as inputs.
    ///
    /// # Parameters
    /// - path (`&str`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text
    /// - special_token_mapping_path (`&str`): path to a special token mapping file to overwrite default special tokens
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLNetTokenizer};
    /// let lower_case = false;
    /// let strip_accents = false;
    /// let tokenizer = XLNetTokenizer::from_file_with_special_token_mapping(
    ///     "path/to/vocab/file",
    ///     lower_case,
    ///     strip_accents,
    ///     "path/to/special/token/mapping/file",
    /// )
    /// .unwrap();
    /// ```
    pub fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        lower_case: bool,
        strip_accents: bool,
        special_token_mapping_path: S,
    ) -> Result<XLNetTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab =
            XLNetVocab::from_file_with_special_token_mapping(path, special_token_mapping_path)?;
        Ok(XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        })
    }

    /// Create a new instance of an `XLNetTokenizer` from an existing vocabulary and model
    ///
    /// # Parameters
    /// - vocab (`XLNetVocab`): vocabulary
    /// - model (`SentencePieceModel`): SentencePiece model
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - strip_accents (`bool`): flag indicating if accents should be stripped from the text
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLNetTokenizer};
    /// use rust_tokenizers::vocab::{SentencePieceModel, Vocab, XLNetVocab};
    /// let lower_case = false;
    /// let strip_accents = false;
    /// let vocab = XLNetVocab::from_file("path/to/vocab/file").unwrap();
    /// let model = SentencePieceModel::from_file("path/to/model/file").unwrap();
    ///
    /// let tokenizer =
    ///     XLNetTokenizer::from_existing_vocab_and_model(vocab, model, lower_case, strip_accents);
    /// ```
    pub fn from_existing_vocab_and_model(
        vocab: XLNetVocab,
        model: SentencePieceModel,
        lower_case: bool,
        strip_accents: bool,
    ) -> XLNetTokenizer {
        XLNetTokenizer {
            model,
            vocab,
            lower_case,
            strip_accents,
        }
    }

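    /// Splits SentencePiece output pieces that end with a comma preceded by a digit
    /// (e.g. `▁1990,`): the trailing comma becomes its own token and the remainder is
    /// decoded again through the SentencePiece model, mirroring the digit/comma
    /// handling of the original XLNet preprocessing.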
    fn post_process_pieces<'a>(&self, tokens: &'a mut Vec<Token>) -> &'a Vec<Token> {
        let mut positions_to_update: Vec<(usize, Vec<Token>)> = vec![];
        for (token_idx, token) in tokens.iter().enumerate() {
            if token.text.chars().count() > 1 {
                let mut token_chars = token.text.chars().rev();
                if (token_chars.next().unwrap() == ',')
                    & token_chars.next().unwrap().is_ascii_digit()
                {
                    let mut new_token = token.clone();
                    let last_char = new_token.text.pop().unwrap();
                    let updated_tokens = self.model.decode_forward_token_ref(new_token.as_ref());
                    let updated_tokens = self.model.decode_backward(&updated_tokens);
                    let mut updated_tokens = self.model.parse_nodes_to_tokens(updated_tokens);
                    if !token.text.starts_with('\u{2581}')
                        & updated_tokens[0].text.starts_with('\u{2581}')
                    {
                        if updated_tokens[0].text.chars().count() == 1 {
                            updated_tokens.remove(0);
                        } else {
                            let first_char_length =
                                updated_tokens[0].text.chars().next().unwrap().len_utf8();
                            updated_tokens[0].text = (updated_tokens[0].text[first_char_length..])
                                .parse()
                                .unwrap();
                        }
                    }
                    updated_tokens.push(Token {
                        text: last_char.to_string(),
                        offset: Offset {
                            begin: token.offset.end,
                            end: token.offset.end,
                        },
                        reference_offsets: vec![*token.reference_offsets.last().unwrap()],
                        mask: token.mask,
                    });
                    positions_to_update.push((token_idx, updated_tokens.clone()));
                }
            }
        }
        for (pos, new_tokens) in positions_to_update {
            tokens.splice(pos..pos + 1, new_tokens);
        }
        tokens
    }
}

impl Tokenizer<XLNetVocab> for XLNetTokenizer {
    fn vocab(&self) -> &XLNetVocab {
        &self.vocab
    }
    fn vocab_mut(&mut self) -> &mut XLNetVocab {
        &mut self.vocab
    }

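    // Tokenization pipeline: split on special tokens, normalize quotes, clean the text,
    // apply NFKC decomposition, optionally lower-case and strip accents, map whitespace
    // to the SentencePiece marker `▁` (U+2581), then decompose each token with the
    // SentencePiece model and post-process the resulting pieces.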
    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        let mut tokens = split_on_special_tokens(text, &self.vocab)
            .into_iter()
            .map(|token| token.to_owned())
            .collect::<Vec<Token>>();

        let mut sub_tokens: Vec<Token> = Vec::new();
        for token in tokens.iter_mut() {
            if token.mask != Mask::Special && token.mask != Mask::Unknown {
                replace_string(token, "``", "\"");
                replace_string(token, "\'\'", "\"");
                clean_text(token, true);
                decompose_nfkc(token);
                if self.lower_case {
                    lowercase(token);
                }
                if self.strip_accents {
                    strip_accents(token);
                }
                token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");
                if !token.text.starts_with('\u{2581}') {
                    token.text.insert(0, '\u{2581}');
                    token.reference_offsets.insert(0, 0);
                };
                let output = self.model.decode_forward_token_ref(token.as_ref());
                let decoded = self.model.decode_backward(&output);

                let mut output: Vec<Token> = self.model.parse_nodes_to_tokens(decoded);
                self.post_process_pieces(&mut output);
                sub_tokens.extend(output)
            } else {
                sub_tokens.push(token.clone());
            }
        }
        sub_tokens
    }

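    // Decoding maps the SentencePiece whitespace marker `▁` (U+2581) back to a regular
    // space and concatenates the pieces.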
    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        tokens
            .into_iter()
            .map(|v| v.replace('\u{2581}', " "))
            .collect::<Vec<String>>()
            .join("")
    }

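    // XLNet inputs are laid out as `A <sep>` (single sequence) or `A <sep> B <sep>`
    // (sequence pair), followed by a trailing `<cls>` token: unlike BERT, the
    // classification token is appended at the end rather than prepended.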
    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        let mut output: Vec<i64> = vec![];
        let mut token_segment_ids: Vec<i8> = vec![];
        let mut special_tokens_mask: Vec<i8> = vec![];
        let mut offsets: Vec<Option<Offset>> = vec![];
        let mut original_offsets: Vec<Vec<OffsetSize>> = vec![];
        let mut mask: Vec<Mask> = vec![];
        // Push the first sequence with a SEP token
        special_tokens_mask.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_ids_with_offsets_1.ids.len() + 1]);
        output.extend(tokens_ids_with_offsets_1.ids);
        output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
        offsets.extend(tokens_ids_with_offsets_1.offsets);
        offsets.push(None);
        original_offsets.extend(tokens_ids_with_offsets_1.reference_offsets);
        original_offsets.push(vec![]);
        mask.extend(tokens_ids_with_offsets_1.masks);
        mask.push(Mask::Special);
        // Push the second sequence with a SEP token if provided
        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.extend(vec![0; length]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; length + 1]);
            output.extend(tokens_ids_with_offsets_2_value.ids);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            offsets.push(None);
            original_offsets.push(vec![]);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
            mask.push(Mask::Special);
        }
        // Push the CLS token at the end of the sequence (XLNet assigns the trailing
        // `<cls>` token its own segment id of 2)
        output.push(self.vocab.token_to_id(self.vocab.get_cls_value()));
        special_tokens_mask.push(1);
        token_segment_ids.push(2);
        offsets.push(None);
        original_offsets.push(vec![]);
        mask.push(Mask::Special);
        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}

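// Enables the default multi-threaded (parallelized) tokenization methods for the
// XLNet tokenizer.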
impl MultiThreadedTokenizer<XLNetVocab> for XLNetTokenizer {}