// tokenizers/models/wordpiece/mod.rs

//! [WordPiece](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf)
//! model.
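//!
//! Given a vocabulary, a WordPiece model greedily splits each word into the
//! longest pieces found in that vocabulary, prefixing every word-internal
//! piece with a continuing subword prefix (`##` by default): `unaffable`
//! becomes `un`, `##aff`, `##able`. A word that cannot be fully decomposed
//! is mapped to a single unknown token.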

use crate::models::bpe::BPE;
use crate::tokenizer::{Model, Result, Token};
use std::{
    borrow::Cow,
    collections::HashMap,
    fs::File,
    io::prelude::*,
    io::{BufRead, BufReader},
    path::{Path, PathBuf},
};

mod serialization;
mod trainer;
pub use trainer::*;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("WordPiece error: Missing [UNK] token from the vocabulary")]
    MissingUnkToken,
}

type Vocab = HashMap<String, u32>;
type VocabR = HashMap<u32, String>;

struct Config {
    files: Option<String>,
    vocab: Vocab,
    unk_token: String,
    continuing_subword_prefix: String,
    max_input_chars_per_word: usize,
}

/// A `WordPieceBuilder` can be used to create a `WordPiece` model with a custom configuration.
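///
/// A minimal sketch of typical usage; the two-entry vocab here is purely
/// illustrative:
///
/// ```
/// use std::collections::HashMap;
/// use tokenizers::models::wordpiece::WordPiece;
/// use tokenizers::tokenizer::Model;
///
/// let vocab = HashMap::from([
///     ("[UNK]".to_string(), 0u32),
///     ("hello".to_string(), 1),
/// ]);
/// let model = WordPiece::builder()
///     .vocab(vocab)
///     .unk_token("[UNK]".to_string())
///     .build()
///     .unwrap();
/// assert_eq!(model.get_vocab_size(), 2);
/// ```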
pub struct WordPieceBuilder {
    config: Config,
}

impl Default for WordPieceBuilder {
    fn default() -> Self {
        Self {
            config: Config {
                files: None,
                vocab: HashMap::new(),
                unk_token: String::from("[UNK]"),
                continuing_subword_prefix: String::from("##"),
                max_input_chars_per_word: 100,
            },
        }
    }
}

impl WordPieceBuilder {
    /// Construct a new `WordPieceBuilder`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the input vocab file.
    #[must_use]
    pub fn files(mut self, vocab: String) -> Self {
        self.config.files = Some(vocab);
        self
    }

    /// Set the vocab (token -> ID) mapping.
    #[must_use]
    pub fn vocab(mut self, vocab: Vocab) -> Self {
        self.config.vocab = vocab;
        self
    }

    /// Set the `UNK` token for the vocab.
    #[must_use]
    pub fn unk_token(mut self, unk_token: String) -> Self {
        self.config.unk_token = unk_token;
        self
    }

    /// Set the prefix for continuing subwords.
    #[must_use]
    pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self {
        self.config.continuing_subword_prefix = continuing_subword_prefix;
        self
    }

    /// Set the maximum number of input characters per word.
    #[must_use]
    pub fn max_input_chars_per_word(mut self, max_input_chars_per_word: usize) -> Self {
        self.config.max_input_chars_per_word = max_input_chars_per_word;
        self
    }

    /// Constructs a `WordPiece` model that uses the `WordPieceBuilder`'s configuration.
    pub fn build(mut self) -> Result<WordPiece> {
        // A vocab file, if provided, takes precedence over a directly-set vocab.
        if let Some(vocab) = self.config.files {
            self.config.vocab = WordPiece::read_file(&vocab)?;
        }

        // Build the reverse (id -> token) mapping from the vocab.
        let vocab_r = self
            .config
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();

        Ok(WordPiece {
            vocab: self.config.vocab,
            vocab_r,
            unk_token: self.config.unk_token,
            continuing_subword_prefix: self.config.continuing_subword_prefix,
            max_input_chars_per_word: self.config.max_input_chars_per_word,
        })
    }
}

/// A
/// [WordPiece](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf)
/// model.
#[derive(Clone, PartialEq, Eq)]
pub struct WordPiece {
    vocab: Vocab,
    vocab_r: VocabR,
    pub unk_token: String,
    pub continuing_subword_prefix: String,
    pub max_input_chars_per_word: usize,
}

impl std::fmt::Debug for WordPiece {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("WordPiece")
            .field("unk_token", &self.unk_token)
            .field("continuing_subword_prefix", &self.continuing_subword_prefix)
            .field("max_input_chars_per_word", &self.max_input_chars_per_word)
            .field("vocab", &self.vocab.len())
            .finish()
    }
}

impl Default for WordPiece {
    fn default() -> Self {
        Self {
            vocab: HashMap::new(),
            vocab_r: HashMap::new(),
            unk_token: String::from("[UNK]"),
            continuing_subword_prefix: String::from("##"),
            max_input_chars_per_word: 100,
        }
    }
}

impl WordPiece {
    /// Get a `WordPieceBuilder`.
    pub fn builder() -> WordPieceBuilder {
        WordPieceBuilder::new()
    }

    /// Read the given vocab file: one token per line, where each token's id is its line index.
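    ///
    /// A minimal sketch; `vocab.txt` is a placeholder path:
    ///
    /// ```no_run
    /// use tokenizers::models::wordpiece::WordPiece;
    ///
    /// let vocab = WordPiece::read_file("vocab.txt").unwrap();
    /// ```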
    pub fn read_file(vocab: &str) -> Result<Vocab> {
        let file = File::open(vocab)?;
        let file = BufReader::new(file);

        let mut vocab = HashMap::new();
        for (index, line) in file.lines().enumerate() {
            let line = line?;
            vocab.insert(line.trim_end().to_owned(), index as u32);
        }

        Ok(vocab)
    }

    /// Initialize a `WordPiece` model from a vocab mapping file.
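    ///
    /// Note that this returns a `WordPieceBuilder`, so the model still needs to
    /// be built. A minimal sketch; `vocab.txt` is a placeholder path:
    ///
    /// ```no_run
    /// use tokenizers::models::wordpiece::WordPiece;
    ///
    /// let model = WordPiece::from_file("vocab.txt").build().unwrap();
    /// ```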
    pub fn from_file(vocab: &str) -> WordPieceBuilder {
        WordPiece::builder().files(vocab.to_owned())
    }

    /// Create a `WordPiece` model from a `BPE` model.
    pub fn from_bpe(bpe: &BPE) -> Self {
        let mut wp = Self::builder().vocab(bpe.get_vocab()).build().unwrap();
        // Carry over the BPE model's unknown token and continuing subword
        // prefix, when they are set.
        if let Some(unk) = bpe.get_unk_token() {
            unk.clone_into(&mut wp.unk_token);
        }
        if let Some(prefix) = bpe.get_continuing_subword_prefix() {
            prefix.clone_into(&mut wp.continuing_subword_prefix);
        }
        wp
    }
}

impl Model for WordPiece {
    type Trainer = WordPieceTrainer;

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }

    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>> {
        // Words longer than `max_input_chars_per_word` are mapped directly
        // to the unknown token.
        let char_len = sequence.chars().count();
        if char_len > self.max_input_chars_per_word {
            return Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }]);
        }

        let mut is_bad = false;
        let mut start = 0;
        let mut sub_tokens: Vec<Token> = vec![];

        // Greedy longest-match-first: at each position, find the longest
        // substring (prefixed with the continuing subword prefix when not at
        // the start of the word) that exists in the vocabulary.
        while start < sequence.len() {
            let mut end = sequence.len();
            let mut cur_str = None;

            while start < end {
                let mut substr: Cow<str> = Cow::Borrowed(&sequence[start..end]);

                if start > 0 {
                    substr = Cow::Owned(format!("{}{}", self.continuing_subword_prefix, substr));
                }
                if self.vocab.contains_key(substr.as_ref()) {
                    cur_str = Some(Token {
                        id: self.vocab[substr.as_ref()],
                        value: substr.to_string(),
                        offsets: (start, end),
                    });
                    break;
                }
                // No match: drop the last character (by its UTF-8 width) and retry.
                end -= substr.chars().last().map_or(1, |c| c.len_utf8());
            }

            // If no piece matched at this position, the whole word is unknown.
            if cur_str.is_none() {
                is_bad = true;
                break;
            }

            sub_tokens.push(cur_str.unwrap());
            start = end;
        }

        if is_bad {
            Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }])
        } else {
            Ok(sub_tokens)
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }

    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        let vocab_file_name = match name {
            Some(name) => format!("{name}-vocab.txt"),
            None => "vocab.txt".to_string(),
        };

        // Write vocab.txt: one token per line, ordered by token id.
        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
            .iter()
            .collect();
        let mut vocab_file = File::create(&vocab_path)?;
        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
        vocab.sort_unstable_by_key(|k| *k.1);
        vocab_file.write_all(
            &vocab
                .into_iter()
                .flat_map(|(token, _)| format!("{token}\n").as_bytes().to_owned())
                .collect::<Vec<_>>()[..],
        )?;

        Ok(vec![vocab_path])
    }

    fn get_trainer(&self) -> Self::Trainer {
        WordPieceTrainer::builder().build()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        assert!(format!("{}", Error::MissingUnkToken).contains("Missing [UNK] token"));
    }
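
    // A sketch of `tokenize`'s greedy longest-match-first behaviour; the
    // vocabulary below is purely illustrative.
    #[test]
    fn test_tokenize_longest_match() {
        let vocab: Vocab = [("[UNK]", 0), ("un", 1), ("##aff", 2), ("##able", 3)]
            .iter()
            .map(|(token, id)| (token.to_string(), *id))
            .collect();
        let model = WordPiece::builder().vocab(vocab).build().unwrap();

        // A word covered by the vocab splits into the longest matching pieces.
        let tokens = model.tokenize("unaffable").unwrap();
        assert_eq!(
            tokens.iter().map(|t| t.value.as_str()).collect::<Vec<_>>(),
            vec!["un", "##aff", "##able"]
        );

        // A word with no full decomposition maps to the unknown token.
        let tokens = model.tokenize("xyz").unwrap();
        assert_eq!(tokens[0].value, "[UNK]");
    }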
}