tokenizers/models/wordlevel/mod.rs

use super::OrderedVocabIter;
use crate::tokenizer::{Model, Result, Token};
use serde_json::Value;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::{Path, PathBuf};

mod serialization;
mod trainer;

pub use trainer::*;

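/// A mapping from a token string to its id in the vocabulary.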
type Vocab = HashMap<String, u32>;

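/// Errors that can arise while using a `WordLevel` model.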
#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("WordLevel error: Missing [UNK] token from the vocabulary")]
    MissingUnkToken,
    #[error("Bad vocabulary json file")]
    BadVocabulary,
}

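/// The configuration gathered by a `WordLevelBuilder` before building.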
struct Config {
    files: Option<String>,
    vocab: HashMap<String, u32>,
    unk_token: String,
}

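/// A `WordLevelBuilder` can be used to create a `WordLevel` model with a
/// custom configuration.
///
/// A minimal usage sketch (marked `ignore` since the exact import path depends
/// on how the crate re-exports this module):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut vocab = HashMap::new();
/// vocab.insert("<unk>".to_string(), 0);
/// vocab.insert("hello".to_string(), 1);
///
/// let model = WordLevelBuilder::default()
///     .vocab(vocab)
///     .unk_token("<unk>".to_string())
///     .build()
///     .unwrap();
/// ```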
pub struct WordLevelBuilder {
    config: Config,
}

impl Default for WordLevelBuilder {
    fn default() -> Self {
        Self {
            config: Config {
                files: None,
                vocab: HashMap::new(),
                unk_token: String::from("<unk>"),
            },
        }
    }
}

impl WordLevelBuilder {
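    /// Construct a new `WordLevelBuilder`.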
    pub fn new() -> Self {
        Self::default()
    }

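    /// Set the input file from which the vocabulary will be read at build time.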
    #[must_use]
    pub fn files(mut self, vocab: String) -> Self {
        self.config.files = Some(vocab);
        self
    }

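    /// Set the vocab (token -> id) mapping.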
    #[must_use]
    pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
        self.config.vocab = vocab;
        self
    }

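    /// Set the token used for words that are not in the vocabulary.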
    #[must_use]
    pub fn unk_token(mut self, unk_token: String) -> Self {
        self.config.unk_token = unk_token;
        self
    }

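    /// Construct a `WordLevel` model from the builder's configuration. If an
    /// input file was set, it is read here and replaces any vocab provided
    /// through `vocab`.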
    pub fn build(mut self) -> Result<WordLevel> {
        if let Some(vocab) = self.config.files {
            self.config.vocab = WordLevel::read_file(&vocab)?;
        }

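        // Build the reversed (id -> token) mapping used by `id_to_token`.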
        let vocab_r = self
            .config
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();

        Ok(WordLevel {
            vocab: self.config.vocab,
            vocab_r,
            unk_token: self.config.unk_token,
        })
    }
}

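/// A model that tokenizes by looking whole words up in its vocabulary, mapping
/// any word it does not know to `unk_token`.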
#[derive(PartialEq, Clone, Eq)]
pub struct WordLevel {
    vocab: HashMap<String, u32>,
    vocab_r: HashMap<u32, String>,
    pub unk_token: String,
}

impl std::fmt::Debug for WordLevel {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("WordLevel")
            .field("unk_token", &self.unk_token)
            .field("vocab", &self.vocab.len())
            .finish()
    }
}

impl WordLevel {
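    /// Get a `WordLevelBuilder`.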
    pub fn builder() -> WordLevelBuilder {
        WordLevelBuilder::new()
    }

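    /// Read a vocabulary file: a single JSON object mapping each token to its
    /// id. Entries whose value is not a number are skipped, and anything other
    /// than a top-level object is rejected with `Error::BadVocabulary`.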
    pub fn read_file(vocab_path: &str) -> Result<Vocab> {
        let vocab_file = File::open(vocab_path)?;
        let mut vocab_file = BufReader::new(vocab_file);
        let mut buffer = String::new();
        let mut vocab = HashMap::new();

        vocab_file.read_to_string(&mut buffer)?;
        let json: Value = serde_json::from_str(&buffer)?;

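        // Walk the top-level JSON object and keep only numeric-valued entries.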
        match json {
            Value::Object(m) => {
                for (token, id) in m {
                    if let Value::Number(id) = id {
                        let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
                        vocab.insert(token, id);
                    }
                }
            }
            _ => return Err(Box::new(Error::BadVocabulary)),
        };
        Ok(vocab)
    }

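    /// Initialize a `WordLevel` model from a vocabulary file.
    ///
    /// A minimal usage sketch; the path below is hypothetical and the example
    /// is marked `ignore` because it needs a real file on disk:
    ///
    /// ```ignore
    /// let model = WordLevel::from_file("path/to/vocab.json", "<unk>".to_string())?;
    /// ```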
    pub fn from_file(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
        let vocab = WordLevel::read_file(vocab_path)?;
        Self::builder().vocab(vocab).unk_token(unk_token).build()
    }
}

impl Default for WordLevel {
    fn default() -> Self {
        Self {
            vocab: HashMap::new(),
            vocab_r: HashMap::new(),
            unk_token: String::from("<unk>"),
        }
    }
}

impl Model for WordLevel {
    type Trainer = WordLevelTrainer;

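    // Look the word up as-is; unknown words fall back to `unk_token`, and if
    // `unk_token` itself is missing from the vocab, tokenization fails.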
    fn tokenize(&self, token: &str) -> Result<Vec<Token>> {
        if let Some(&id) = self.vocab.get(token) {
            Ok(vec![Token {
                id,
                value: token.to_owned(),
                offsets: (0, token.len()),
            }])
        } else if let Some(&unk_id) = self.vocab.get(&self.unk_token) {
            Ok(vec![Token {
                id: unk_id,
                value: self.unk_token.to_owned(),
                offsets: (0, token.len()),
            }])
        } else {
            Err(Box::new(Error::MissingUnkToken))
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    fn get_vocab_size(&self) -> usize {
        self.vocab.keys().len()
    }

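    // Saving a WordLevel model only writes a single vocab json file, with
    // entries emitted in id order via `OrderedVocabIter`.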
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        let vocab_file_name = match name {
            Some(name) => format!("{name}-vocab.json"),
            None => "vocab.json".to_string(),
        };

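        // Write vocab.json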
        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
            .iter()
            .collect();
        let mut vocab_file = File::create(&vocab_path)?;
        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
        let serialized = serde_json::to_string(&order_vocab_iter)?;
        vocab_file.write_all(serialized.as_bytes())?;

        Ok(vec![vocab_path])
    }

    fn get_trainer(&self) -> Self::Trainer {
        WordLevelTrainer::default()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_unk() {
        let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
            .iter()
            .cloned()
            .collect();
        let wordlevel = WordLevelBuilder::default()
            .vocab(vocab)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();
        let tokens = wordlevel.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1))]);

        let tokens = wordlevel.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "a".into(), (0, 1))]);
    }

    #[test]
    fn test_tokenize_missing_unk_token() {
        let vocab: Vocab = [("a".into(), 0), ("b".into(), 1)].iter().cloned().collect();
        let wordlevel = WordLevelBuilder::default().vocab(vocab).build().unwrap();
        let tokens = wordlevel.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "a".into(), (0, 1))]);

        let error = wordlevel.tokenize("c").err().unwrap();
        assert!(error.is::<Error>());
    }
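
    // An extra sanity check, sketched here rather than taken from the original
    // test suite: `token_to_id` and `id_to_token` should round-trip through the
    // forward and reversed vocab maps built by the builder.
    #[test]
    fn test_token_id_roundtrip() {
        let vocab: Vocab = [("a".into(), 0), ("b".into(), 1)].iter().cloned().collect();
        let wordlevel = WordLevelBuilder::default().vocab(vocab).build().unwrap();
        assert_eq!(wordlevel.token_to_id("a"), Some(0));
        assert_eq!(wordlevel.id_to_token(0), Some("a".to_string()));
        assert_eq!(wordlevel.token_to_id("c"), None);
        assert_eq!(wordlevel.get_vocab_size(), 2);
    }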
}