// tokenizers/models/wordlevel/mod.rs

1use super::OrderedVocabIter;
2use crate::tokenizer::{Model, Result, Token};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fs::File;
6use std::io::{BufReader, Read, Write};
7use std::path::{Path, PathBuf};
8
9mod serialization;
10mod trainer;
11
12// Re-export
13pub use trainer::*;
14
/// Forward vocabulary: token string -> token id.
type Vocab = HashMap<String, u32>;
16
/// Errors that a `WordLevel` model can produce.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An out-of-vocabulary token was seen, but the configured `unk_token`
    /// itself is not present in the vocabulary.
    #[error("WordLevel error: Missing [UNK] token from the vocabulary")]
    MissingUnkToken,
    /// The vocabulary JSON file could not be interpreted as a
    /// token -> id object.
    #[error("Bad vocabulary json file")]
    BadVocabulary,
}
24
/// Internal configuration accumulated by a `WordLevelBuilder`.
struct Config {
    /// Optional path to a vocab JSON file, loaded at `build` time.
    files: Option<String>,
    /// In-memory vocabulary (token -> id).
    vocab: HashMap<String, u32>,
    /// Token substituted for out-of-vocabulary words.
    unk_token: String,
}
30
/// A `WordLevelBuilder` can be used to create a `WordLevel`
/// model with a custom configuration.
///
/// Obtained through [`WordLevel::builder`] or [`WordLevelBuilder::new`].
pub struct WordLevelBuilder {
    config: Config,
}
36
37impl Default for WordLevelBuilder {
38    fn default() -> Self {
39        Self {
40            config: Config {
41                files: None,
42                vocab: HashMap::new(),
43                unk_token: String::from("<unk>"),
44            },
45        }
46    }
47}
48
impl WordLevelBuilder {
    /// Construct a new `WordLevelBuilder`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the input files.
    #[must_use]
    pub fn files(mut self, vocab: String) -> Self {
        self.config.files = Some(vocab);
        self
    }

    /// Set the vocab (token -> ID) mapping.
    #[must_use]
    pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
        self.config.vocab = vocab;
        self
    }

    /// Set the `UNK` token for the vocab.
    #[must_use]
    pub fn unk_token(mut self, unk_token: String) -> Self {
        self.config.unk_token = unk_token;
        self
    }

    /// Constructs a `WordLevel` model that uses the `WordLevelBuilder`'s configuration.
    ///
    /// If an input file was provided through [`Self::files`], the vocabulary
    /// is read from disk here, overriding any vocab set via [`Self::vocab`].
    pub fn build(mut self) -> Result<WordLevel> {
        if let Some(vocab) = self.config.files {
            self.config.vocab = WordLevel::read_file(&vocab)?;
        }

        // Derive the reverse mapping (id -> token) from the forward vocab.
        let vocab_r = self
            .config
            .vocab
            .iter()
            .map(|(key, val)| (*val, key.to_owned()))
            .collect();

        Ok(WordLevel {
            vocab: self.config.vocab,
            vocab_r,
            unk_token: self.config.unk_token,
        })
    }
}
96
/// A model that tokenizes by exact whole-word vocabulary lookup, mapping
/// out-of-vocabulary words to `unk_token`.
#[derive(PartialEq, Clone, Eq)]
pub struct WordLevel {
    /// Forward mapping: token -> id.
    vocab: HashMap<String, u32>,
    /// Reverse mapping: id -> token.
    vocab_r: HashMap<u32, String>,
    /// Token substituted for out-of-vocabulary words.
    pub unk_token: String,
}
103
impl std::fmt::Debug for WordLevel {
    // Manual impl: show only the vocab *size* instead of dumping every entry.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("WordLevel")
            .field("unk_token", &self.unk_token)
            .field("vocab", &self.vocab.len())
            .finish()
    }
}
112
113impl WordLevel {
114    pub fn builder() -> WordLevelBuilder {
115        WordLevelBuilder::new()
116    }
117
118    pub fn read_file(vocab_path: &str) -> Result<Vocab> {
119        let vocab_file = File::open(vocab_path)?;
120        let mut vocab_file = BufReader::new(vocab_file);
121        let mut buffer = String::new();
122        let mut vocab = HashMap::new();
123
124        vocab_file.read_to_string(&mut buffer)?;
125        let json: Value = serde_json::from_str(&buffer)?;
126
127        match json {
128            Value::Object(m) => {
129                for (token, id) in m {
130                    if let Value::Number(id) = id {
131                        let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
132                        vocab.insert(token, id);
133                    }
134                }
135            }
136            _ => return Err(Box::new(Error::BadVocabulary)),
137        };
138        Ok(vocab)
139    }
140
141    /// Initialize a WordLevel model from vocab and merges file.
142    pub fn from_file(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
143        let vocab = WordLevel::read_file(vocab_path)?;
144        Self::builder().vocab(vocab).unk_token(unk_token).build()
145    }
146}
147
148impl Default for WordLevel {
149    fn default() -> Self {
150        Self {
151            vocab: HashMap::new(),
152            vocab_r: HashMap::new(),
153            unk_token: String::from("<unk>"),
154        }
155    }
156}
157
158impl Model for WordLevel {
159    type Trainer = WordLevelTrainer;
160
161    fn tokenize(&self, token: &str) -> Result<Vec<Token>> {
162        if let Some(&id) = self.vocab.get(token) {
163            Ok(vec![Token {
164                id,
165                value: token.to_owned(),
166                offsets: (0, token.len()),
167            }])
168        } else if let Some(&unk_id) = self.vocab.get(&self.unk_token) {
169            Ok(vec![Token {
170                id: unk_id,
171                value: self.unk_token.to_owned(),
172                offsets: (0, token.len()),
173            }])
174        } else {
175            Err(Box::new(Error::MissingUnkToken))
176        }
177    }
178
179    fn token_to_id(&self, token: &str) -> Option<u32> {
180        self.vocab.get(token).copied()
181    }
182
183    fn id_to_token(&self, id: u32) -> Option<String> {
184        self.vocab_r.get(&id).cloned()
185    }
186
187    fn get_vocab(&self) -> HashMap<String, u32> {
188        self.vocab.clone()
189    }
190
191    fn get_vocab_size(&self) -> usize {
192        self.vocab.keys().len()
193    }
194
195    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
196        let vocab_file_name = match name {
197            Some(name) => format!("{name}-vocab.json"),
198            None => "vocab.json".to_string(),
199        };
200
201        // Write vocab.json
202        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
203            .iter()
204            .collect();
205        let mut vocab_file = File::create(&vocab_path)?;
206        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
207        let serialized = serde_json::to_string(&order_vocab_iter)?;
208        vocab_file.write_all(serialized.as_bytes())?;
209
210        Ok(vec![vocab_path])
211    }
212
213    fn get_trainer(&self) -> Self::Trainer {
214        WordLevelTrainer::default()
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_unk() {
        let vocab = Vocab::from([
            ("<unk>".to_string(), 0),
            ("a".to_string(), 1),
            ("b".to_string(), 2),
        ]);
        let wordlevel = WordLevelBuilder::default()
            .vocab(vocab)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        // An out-of-vocabulary token falls back to <unk>.
        let tokens = wordlevel.tokenize("c").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1))]);

        // A known token resolves to its own id.
        let tokens = wordlevel.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(1u32, "a".into(), (0, 1))]);
    }

    #[test]
    fn test_tokenize_missing_unk_token() {
        let vocab = Vocab::from([("a".to_string(), 0), ("b".to_string(), 1)]);
        let wordlevel = WordLevelBuilder::default().vocab(vocab).build().unwrap();

        let tokens = wordlevel.tokenize("a").unwrap();
        assert_eq!(tokens, vec![Token::new(0u32, "a".into(), (0, 1))]);

        // Without <unk> in the vocab, unknown tokens are an error.
        let error = wordlevel.tokenize("c").err().unwrap();
        assert!(error.is::<Error>());
    }
}