//! WordPiece model (tokenizers/models/wordpiece/mod.rs)
use crate::models::bpe::BPE;
5use crate::tokenizer::{Model, Result, Token};
6use std::{
7 borrow::Cow,
8 collections::HashMap,
9 fs::File,
10 io::prelude::*,
11 io::{BufRead, BufReader},
12 path::{Path, PathBuf},
13};
14
15mod serialization;
16mod trainer;
17pub use trainer::*;
18
/// Errors that the `WordPiece` model can raise.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The configured unknown token is absent from the vocabulary, so
    /// out-of-vocabulary input cannot be represented.
    #[error("WordPiece error: Missing [UNK] token from the vocabulary")]
    MissingUnkToken,
}
24
// Forward (token -> id) and reverse (id -> token) vocabulary mappings.
type Vocab = HashMap<String, u32>;
type VocabR = HashMap<u32, String>;
27
/// Options accumulated by `WordPieceBuilder` until `build()` is called.
struct Config {
    // Optional path to a vocab file, read at build time (overrides `vocab`).
    files: Option<String>,
    // In-memory vocabulary, used when no file is provided.
    vocab: Vocab,
    // Token substituted for words that cannot be tokenized.
    unk_token: String,
    // Prefix prepended to non-word-initial subwords (e.g. "##").
    continuing_subword_prefix: String,
    // Words with more characters than this map directly to `unk_token`.
    max_input_chars_per_word: usize,
}
35
/// A builder for configuring and constructing a `WordPiece` model.
pub struct WordPieceBuilder {
    // Accumulated configuration, consumed by `build()`.
    config: Config,
}
40
41impl Default for WordPieceBuilder {
42 fn default() -> Self {
43 Self {
44 config: Config {
45 files: None,
46 vocab: HashMap::new(),
47 unk_token: String::from("[UNK]"),
48 continuing_subword_prefix: String::from("##"),
49 max_input_chars_per_word: 100,
50 },
51 }
52 }
53}
54
55impl WordPieceBuilder {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 #[must_use]
63 pub fn files(mut self, vocab: String) -> Self {
64 self.config.files = Some(vocab);
65 self
66 }
67
68 #[must_use]
70 pub fn vocab(mut self, vocab: Vocab) -> Self {
71 self.config.vocab = vocab;
72 self
73 }
74
75 #[must_use]
77 pub fn unk_token(mut self, unk_token: String) -> Self {
78 self.config.unk_token = unk_token;
79 self
80 }
81
82 #[must_use]
84 pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self {
85 self.config.continuing_subword_prefix = continuing_subword_prefix;
86 self
87 }
88
89 #[must_use]
91 pub fn max_input_chars_per_word(mut self, max_input_chars_per_word: usize) -> Self {
92 self.config.max_input_chars_per_word = max_input_chars_per_word;
93 self
94 }
95
96 pub fn build(mut self) -> Result<WordPiece> {
98 if let Some(vocab) = self.config.files {
99 self.config.vocab = WordPiece::read_file(&vocab)?;
100 }
101
102 let vocab_r = self
103 .config
104 .vocab
105 .iter()
106 .map(|(key, val)| (*val, key.to_owned()))
107 .collect();
108
109 Ok(WordPiece {
110 vocab: self.config.vocab,
111 vocab_r,
112 unk_token: self.config.unk_token,
113 continuing_subword_prefix: self.config.continuing_subword_prefix,
114 max_input_chars_per_word: self.config.max_input_chars_per_word,
115 })
116 }
117}
118
/// A WordPiece model: greedy longest-match-first subword tokenization over
/// a fixed vocabulary.
#[derive(Clone, PartialEq, Eq)]
pub struct WordPiece {
    // Token -> id mapping.
    vocab: Vocab,
    // Reverse id -> token mapping, kept in sync with `vocab`.
    vocab_r: VocabR,
    /// Token emitted for words that cannot be tokenized.
    pub unk_token: String,
    /// Prefix prepended to non-word-initial subwords (e.g. "##").
    pub continuing_subword_prefix: String,
    /// Words longer than this (in chars) map directly to `unk_token`.
    pub max_input_chars_per_word: usize,
}
130
131impl std::fmt::Debug for WordPiece {
132 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
133 fmt.debug_struct("WordPiece")
134 .field("unk_token", &self.unk_token)
135 .field("continuing_subword_prefix", &self.continuing_subword_prefix)
136 .field("max_input_chars_per_word", &self.max_input_chars_per_word)
137 .field("vocab", &self.vocab.len())
138 .finish()
139 }
140}
141
142impl Default for WordPiece {
143 fn default() -> Self {
144 Self {
145 vocab: HashMap::new(),
146 vocab_r: HashMap::new(),
147 unk_token: String::from("[UNK]"),
148 continuing_subword_prefix: String::from("##"),
149 max_input_chars_per_word: 100,
150 }
151 }
152}
153
154impl WordPiece {
155 pub fn builder() -> WordPieceBuilder {
157 WordPieceBuilder::new()
158 }
159
160 pub fn read_file(vocab: &str) -> Result<Vocab> {
162 let file = File::open(vocab)?;
163 let file = BufReader::new(file);
164
165 let mut vocab = HashMap::new();
166 for (index, line) in file.lines().enumerate() {
167 let line = line?;
168 vocab.insert(line.trim_end().to_owned(), index as u32);
169 }
170
171 Ok(vocab)
172 }
173
174 pub fn from_file(vocab: &str) -> WordPieceBuilder {
176 WordPiece::builder().files(vocab.to_owned())
177 }
178
179 pub fn from_bpe(bpe: &BPE) -> Self {
181 let mut wp = Self::builder().vocab(bpe.get_vocab()).build().unwrap();
182 if let Some(unk) = bpe.get_unk_token() {
183 unk.clone_into(&mut wp.unk_token);
184 }
185 if let Some(prefix) = bpe.get_continuing_subword_prefix() {
186 prefix.clone_into(&mut wp.continuing_subword_prefix);
187 }
188 wp
189 }
190}
191
impl Model for WordPiece {
    type Trainer = WordPieceTrainer;

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Tokenize one word with greedy longest-match-first WordPiece:
    /// starting at each unmatched position, find the longest vocab entry
    /// matching the remaining text (prefixed with
    /// `continuing_subword_prefix` when not word-initial). If the word is
    /// too long or any position has no match at all, the whole word becomes
    /// a single `unk_token` token.
    ///
    /// # Errors
    /// Returns `Error::MissingUnkToken` if the unk token is needed but is
    /// not present in the vocabulary.
    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>> {
        let char_len = sequence.chars().count();
        if char_len > self.max_input_chars_per_word {
            // Overly long words are not searched at all: emit a single unk
            // token spanning the whole input (offsets are byte offsets).
            return Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }]);
        }

        let mut is_bad = false;
        // `start` is the byte offset of the first not-yet-matched position.
        let mut start = 0;
        let mut sub_tokens: Vec<Token> = vec![];

        while start < sequence.len() {
            let mut end = sequence.len();
            let mut cur_str = None;

            // Shrink `end` until the [start, end) slice (plus the prefix
            // when mid-word) is found in the vocab — longest match wins.
            while start < end {
                let mut substr: Cow<str> = Cow::Borrowed(&sequence[start..end]);

                if start > 0 {
                    substr = Cow::Owned(format!("{}{}", self.continuing_subword_prefix, substr));
                }
                if self.vocab.contains_key(substr.as_ref()) {
                    cur_str = Some(Token {
                        id: self.vocab[substr.as_ref()],
                        value: substr.to_string(),
                        offsets: (start, end),
                    });
                    break;
                }
                // Back off by one full character (not one byte) so `end`
                // always stays on a UTF-8 boundary.
                end -= substr.chars().last().map_or(1, |c| c.len_utf8());
            }

            if cur_str.is_none() {
                // No vocab entry matches at this position: the whole word
                // is unrepresentable and falls back to unk below.
                is_bad = true;
                break;
            }

            sub_tokens.push(cur_str.unwrap());
            start = end;
        }

        if is_bad {
            Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }])
        } else {
            Ok(sub_tokens)
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }

    /// Write the vocabulary to `vocab.txt` (or `<name>-vocab.txt`) in
    /// `folder`, one token per line ordered by id, and return the path
    /// written.
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        let vocab_file_name = match name {
            Some(name) => format!("{name}-vocab.txt"),
            None => "vocab.txt".to_string(),
        };

        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
            .iter()
            .collect();
        let mut vocab_file = File::create(&vocab_path)?;
        // Sort entries by id so that the line index equals the token id
        // when the file is reloaded via `read_file`.
        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
        vocab.sort_unstable_by_key(|k| *k.1);
        vocab_file.write_all(
            &vocab
                .into_iter()
                .flat_map(|(token, _)| format!("{token}\n").as_bytes().to_owned())
                .collect::<Vec<_>>()[..],
        )?;

        Ok(vec![vocab_path])
    }

    fn get_trainer(&self) -> Self::Trainer {
        WordPieceTrainer::builder().build()
    }
}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        // The Display impl is generated by thiserror's #[error(...)] attribute.
        let msg = Error::MissingUnkToken.to_string();
        assert!(msg.contains("Missing [UNK] token"));
    }
}
308}