rust_tokenizers/tokenizer/xlm_roberta_tokenizer.rs

// Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::path::Path;

use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{
    Mask, Offset, OffsetSize, Token, TokenIdsWithOffsets, TokenIdsWithSpecialTokens, TokenRef,
};
use crate::tokenizer::tokenization_utils::{
    clean_text, decompose_nfkc, is_whitespace, lowercase, split_on_special_tokens,
};
use crate::tokenizer::{MultiThreadedTokenizer, Tokenizer};
use crate::vocab::{SentencePieceModel, Vocab, XLMRobertaVocab};

/// # XLM RoBERTa tokenizer
/// XLM RoBERTa tokenizer performing:
/// - Splitting on special tokens
/// - Text cleaning
/// - NFKC decomposition
/// - (optional) lower casing
/// - SentencePiece decomposition
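///
/// A minimal end-to-end sketch of this pipeline (the path below is a placeholder for a real
/// SentencePiece model file):
///
/// ```no_run
/// use rust_tokenizers::tokenizer::{Tokenizer, XLMRobertaTokenizer};
/// let tokenizer = XLMRobertaTokenizer::from_file("path/to/vocab/file", false).unwrap();
/// let tokens = tokenizer.tokenize("Hello, world!");
/// ```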
#[allow(clippy::upper_case_acronyms)]
pub struct XLMRobertaTokenizer {
    model: SentencePieceModel,
    vocab: XLMRobertaVocab,
    lower_case: bool,
}

impl XLMRobertaTokenizer {
    /// Create a new instance of an `XLMRobertaTokenizer`.
    /// Expects a SentencePiece protobuf file as input; both the vocabulary and the SentencePiece model are loaded from it.
    ///
    /// # Parameters
    /// - path (`&str`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLMRobertaTokenizer};
    /// let lower_case = false;
    /// let tokenizer = XLMRobertaTokenizer::from_file("path/to/vocab/file", lower_case).unwrap();
    /// ```
    pub fn from_file<P: AsRef<Path>>(
        path: P,
        lower_case: bool,
    ) -> Result<XLMRobertaTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLMRobertaVocab::from_file(path)?;
        Ok(XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        })
    }

    /// Create a new instance of an `XLMRobertaTokenizer`.
    /// Expects a SentencePiece protobuf file and a special token mapping file as inputs.
    ///
    /// # Parameters
    /// - path (`&str`): path to the SentencePiece model file
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    /// - special_token_mapping_path (`&str`): path to a special token mapping file to overwrite default special tokens
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLMRobertaTokenizer};
    /// let lower_case = false;
    /// let tokenizer = XLMRobertaTokenizer::from_file_with_special_token_mapping(
    ///     "path/to/vocab/file",
    ///     lower_case,
    ///     "path/to/special/token/mapping/file",
    /// )
    /// .unwrap();
    /// ```
    pub fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        lower_case: bool,
        special_token_mapping_path: S,
    ) -> Result<XLMRobertaTokenizer, TokenizerError> {
        let model = SentencePieceModel::from_file(&path)?;
        let vocab = XLMRobertaVocab::from_file_with_special_token_mapping(
            path,
            special_token_mapping_path,
        )?;
        Ok(XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        })
    }

    /// Create a new instance of an `XLMRobertaTokenizer` from an existing vocabulary and model.
    ///
    /// # Parameters
    /// - vocab (`XLMRobertaVocab`): vocabulary
    /// - model (`SentencePieceModel`): SentencePiece model
    /// - lower_case (`bool`): flag indicating if the text should be lower-cased as part of the tokenization
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_tokenizers::tokenizer::{Tokenizer, XLMRobertaTokenizer};
    /// use rust_tokenizers::vocab::{SentencePieceModel, Vocab, XLMRobertaVocab};
    /// let lower_case = false;
    /// let vocab = XLMRobertaVocab::from_file("path/to/vocab/file").unwrap();
    /// let model = SentencePieceModel::from_file("path/to/model/file").unwrap();
    ///
    /// let tokenizer = XLMRobertaTokenizer::from_existing_vocab_and_model(vocab, model, lower_case);
    /// ```
    pub fn from_existing_vocab_and_model(
        vocab: XLMRobertaVocab,
        model: SentencePieceModel,
        lower_case: bool,
    ) -> XLMRobertaTokenizer {
        XLMRobertaTokenizer {
            model,
            vocab,
            lower_case,
        }
    }
}

impl Tokenizer<XLMRobertaVocab> for XLMRobertaTokenizer {
    fn vocab(&self) -> &XLMRobertaVocab {
        &self.vocab
    }
    fn vocab_mut(&mut self) -> &mut XLMRobertaVocab {
        &mut self.vocab
    }

    fn tokenize_to_tokens(&self, text: TokenRef) -> Vec<Token> {
        // Isolate special tokens first so that they bypass cleaning and SentencePiece decomposition
        let mut tokens = split_on_special_tokens(text, &self.vocab)
            .into_iter()
            .map(|token| token.to_owned())
            .collect::<Vec<Token>>();

        let mut sub_tokens: Vec<Token> = Vec::new();
        for token in tokens.iter_mut() {
            if token.mask != Mask::Special && token.mask != Mask::Unknown {
                clean_text(token, true);
                decompose_nfkc(token);
                if self.lower_case {
                    lowercase(token);
                }
                // Replace whitespace with the SentencePiece word-boundary marker (U+2581)
                // and ensure the token starts with one
                token.text = token.text.replace(|c: char| is_whitespace(&c), "\u{2581}");
                if !token.text.starts_with('\u{2581}') {
                    token.text.insert(0, '\u{2581}');
                    token.reference_offsets.insert(0, 0);
                };
                // Run the SentencePiece forward/backward passes and convert the best path into sub-tokens
                let output = self.model.decode_forward_token_ref(token.as_ref());
                let decoded = self.model.decode_backward(&output);

                let output: Vec<Token> = self.model.parse_nodes_to_tokens(decoded);
                sub_tokens.extend(output)
            } else {
                sub_tokens.push(token.clone());
            }
        }
        sub_tokens
    }

    fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
        // Map the SentencePiece word-boundary marker (U+2581) back to a plain space
        tokens
            .into_iter()
            .map(|v| v.replace('\u{2581}', " "))
            .collect::<Vec<String>>()
            .join("")
    }

    fn build_input_with_special_tokens(
        &self,
        tokens_ids_with_offsets_1: TokenIdsWithOffsets,
        tokens_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
    ) -> TokenIdsWithSpecialTokens {
        // Builds the XLM-RoBERTa input layout: <s> A </s> for a single sequence and
        // <s> A </s> </s> B </s> for a sequence pair
        let mut output: Vec<i64> = vec![];
        let mut token_segment_ids: Vec<i8> = vec![];
        let mut special_tokens_mask: Vec<i8> = vec![];
        let mut offsets: Vec<Option<Offset>> = vec![];
        let mut original_offsets: Vec<Vec<OffsetSize>> = vec![];
        let mut mask: Vec<Mask> = vec![];
        special_tokens_mask.push(1);
        special_tokens_mask.extend(vec![0; tokens_ids_with_offsets_1.ids.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_ids_with_offsets_1.ids.len() + 2]);
        output.push(self.vocab.token_to_id(self.vocab.get_cls_value()));
        output.extend(tokens_ids_with_offsets_1.ids);
        output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
        offsets.push(None);
        offsets.extend(tokens_ids_with_offsets_1.offsets);
        offsets.push(None);
        original_offsets.push(vec![]);
        original_offsets.extend(tokens_ids_with_offsets_1.reference_offsets);
        original_offsets.push(vec![]);
        mask.push(Mask::Special);
        mask.extend(tokens_ids_with_offsets_1.masks);
        mask.push(Mask::Special);
        // Append the second sequence (if any), surrounded by its own separator tokens
        if let Some(tokens_ids_with_offsets_2_value) = tokens_ids_with_offsets_2 {
            let length = tokens_ids_with_offsets_2_value.ids.len();
            special_tokens_mask.push(1);
            special_tokens_mask.extend(vec![0; length]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; length + 2]);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            output.extend(tokens_ids_with_offsets_2_value.ids);
            output.push(self.vocab.token_to_id(self.vocab.get_sep_value()));
            offsets.push(None);
            offsets.extend(tokens_ids_with_offsets_2_value.offsets);
            original_offsets.push(vec![]);
            original_offsets.extend(tokens_ids_with_offsets_2_value.reference_offsets);
            offsets.push(None);
            original_offsets.push(vec![]);
            mask.push(Mask::Special);
            mask.extend(tokens_ids_with_offsets_2_value.masks);
            mask.push(Mask::Special);
        }
        TokenIdsWithSpecialTokens {
            token_ids: output,
            segment_ids: token_segment_ids,
            special_tokens_mask,
            token_offsets: offsets,
            reference_offsets: original_offsets,
            mask,
        }
    }
}

impl MultiThreadedTokenizer<XLMRobertaVocab> for XLMRobertaTokenizer {}
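
#[cfg(test)]
mod tests {
    // Illustrative sketch only (not part of the original test suite): exercises the
    // tokenize -> convert_tokens_to_string round trip. The model path is a placeholder,
    // so the test is ignored by default and must be pointed at a real SentencePiece file.
    use super::*;

    #[test]
    #[ignore]
    fn tokenize_round_trip_sketch() {
        let tokenizer = XLMRobertaTokenizer::from_file("path/to/vocab/file", false).unwrap();
        let tokens = tokenizer.tokenize("Hello world!");
        // SentencePiece prefixes word-initial pieces with '\u{2581}'; joining the pieces and
        // replacing the marker with spaces should recover text close to the cleaned input.
        let text = tokenizer.convert_tokens_to_string(tokens);
        assert!(text.contains("Hello"));
    }
}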