tokenizers/models/mod.rs

//! Popular tokenizer models.

pub mod bpe;
pub mod unigram;
pub mod wordlevel;
pub mod wordpiece;

use ahash::AHashMap;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
use crate::models::wordlevel::{WordLevel, WordLevelTrainer};
use crate::models::wordpiece::{WordPiece, WordPieceTrainer};
use crate::{AddedToken, Model, Result, Token, Trainer};

/// Wraps a reverse vocab mapping (ID -> token) in a struct that serializes
/// in order of token ID, smallest to largest.
struct OrderedVocabIter<'a> {
    vocab_r: &'a AHashMap<u32, String>,
}

impl<'a> OrderedVocabIter<'a> {
    fn new(vocab_r: &'a AHashMap<u32, String>) -> Self {
        Self { vocab_r }
    }
}

impl Serialize for OrderedVocabIter<'_> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The vocab can contain holes, so iterating over 0..=max is more correct
        // than iterating over vocab_r.len() entries.
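        // (e.g. a reverse vocab {0: "Hi", 2: "There"} has len 2 but must cover ids 0..=2).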
        let mut holes = vec![];
        let result = if let Some(max) = self.vocab_r.keys().max() {
            let iter = (0..=*max).filter_map(|i| {
                if let Some(token) = self.vocab_r.get(&i) {
                    Some((token, i))
                } else {
                    holes.push(i);
                    None
                }
            });
            serializer.collect_map(iter)
        } else {
            serializer.collect_map(std::iter::empty::<(&str, u32)>())
        };

        if !holes.is_empty() {
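            // The warning goes through both the `log` facade and stdout,
            // presumably so it still surfaces when no logger is configured.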
            warn!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted!");
            println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted!");
        }
        result
    }
}

#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
    BPE(BPE),
    // WordPiece must stay before WordLevel here for deserialization (for backward
    // compatibility with versions that did not include the "type" field), since
    // WordLevel is a subset of WordPiece.
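    // Serde's untagged matching tries variants in order and ignores unknown
    // fields, so a WordPiece payload would otherwise also deserialize as a
    // WordLevel and silently drop its extra fields.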
    WordPiece(WordPiece),
    WordLevel(WordLevel),
    Unigram(Unigram),
}

impl<'de> Deserialize<'de> for ModelWrapper {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        pub struct Tagged {
            #[serde(rename = "type")]
            variant: EnumType,
            #[serde(flatten)]
            rest: serde_json::Value,
        }
        #[derive(Deserialize)]
        pub enum EnumType {
            BPE,
            WordPiece,
            WordLevel,
            Unigram,
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelHelper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelUntagged {
            BPE(BPE),
            // WordPiece must stay before WordLevel here for deserialization (for
            // backward compatibility with versions that did not include the "type"
            // field), since WordLevel is a subset of WordPiece.
            WordPiece(WordPiece),
            WordLevel(WordLevel),
            Unigram(Unigram),
        }

        let helper = ModelHelper::deserialize(deserializer)?;
        Ok(match helper {
            ModelHelper::Tagged(model) => match model.variant {
                EnumType::BPE => ModelWrapper::BPE(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordPiece => ModelWrapper::WordPiece(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordLevel => ModelWrapper::WordLevel(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::Unigram => ModelWrapper::Unigram(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
            },
            ModelHelper::Legacy(value) => {
                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                match untagged {
                    ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
                    ModelUntagged::WordPiece(wp) => ModelWrapper::WordPiece(wp),
                    ModelUntagged::WordLevel(wl) => ModelWrapper::WordLevel(wl),
                    ModelUntagged::Unigram(u) => ModelWrapper::Unigram(u),
                }
            }
        })
    }
}
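
// Illustrative sketch of the two accepted JSON shapes, both deserializing to a
// `ModelWrapper::WordLevel` (payloads abbreviated; the exact field sets belong
// to the concrete models):
//   tagged:  {"type":"WordLevel","vocab":{...},"unk_token":"<unk>"}
//   legacy:  {"vocab":{...},"unk_token":"<unk>"}   (resolved by `ModelUntagged` order)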

impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
impl_enum_from!(Unigram, ModelWrapper, Unigram);
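// The `impl_enum_from!` invocations above (a helper macro defined elsewhere in
// this crate) generate `From` impls, so concrete models lift into the wrapper
// with `.into()`, e.g.:
//     let wrapper: ModelWrapper = Unigram::default().into();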

impl Model for ModelWrapper {
    type Trainer = TrainerWrapper;

    fn tokenize(&self, tokens: &str) -> Result<Vec<Token>> {
        match self {
            Self::WordLevel(t) => t.tokenize(tokens),
            Self::WordPiece(t) => t.tokenize(tokens),
            Self::BPE(t) => t.tokenize(tokens),
            Self::Unigram(t) => t.tokenize(tokens),
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        match self {
            Self::WordLevel(t) => t.token_to_id(token),
            Self::WordPiece(t) => t.token_to_id(token),
            Self::BPE(t) => t.token_to_id(token),
            Self::Unigram(t) => t.token_to_id(token),
        }
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        match self {
            Self::WordLevel(t) => t.id_to_token(id),
            Self::WordPiece(t) => t.id_to_token(id),
            Self::BPE(t) => t.id_to_token(id),
            Self::Unigram(t) => t.id_to_token(id),
        }
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        match self {
            Self::WordLevel(t) => t.get_vocab(),
            Self::WordPiece(t) => t.get_vocab(),
            Self::BPE(t) => t.get_vocab(),
            Self::Unigram(t) => t.get_vocab(),
        }
    }

    fn get_vocab_size(&self) -> usize {
        match self {
            Self::WordLevel(t) => t.get_vocab_size(),
            Self::WordPiece(t) => t.get_vocab_size(),
            Self::BPE(t) => t.get_vocab_size(),
            Self::Unigram(t) => t.get_vocab_size(),
        }
    }

    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        match self {
            Self::WordLevel(t) => t.save(folder, name),
            Self::WordPiece(t) => t.save(folder, name),
            Self::BPE(t) => t.save(folder, name),
            Self::Unigram(t) => t.save(folder, name),
        }
    }

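    // Each model yields its own trainer type; `.into()` lifts it into
    // `TrainerWrapper` via the `From` impls generated below.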
    fn get_trainer(&self) -> Self::Trainer {
        match self {
            Self::WordLevel(t) => t.get_trainer().into(),
            Self::WordPiece(t) => t.get_trainer().into(),
            Self::BPE(t) => t.get_trainer().into(),
            Self::Unigram(t) => t.get_trainer().into(),
        }
    }
}

impl ModelWrapper {
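    /// Clears the wrapped model's internal cache where one exists (`BPE` and
    /// `Unigram`); a no-op for the other variants.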
    pub fn clear_cache(&mut self) {
        match self {
            Self::Unigram(model) => model.clear_cache(),
            Self::BPE(model) => model.clear_cache(),
            _ => (),
        }
    }
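    /// Resizes the wrapped model's internal cache to `capacity` entries where
    /// one exists (`BPE` and `Unigram`); a no-op for the other variants.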
    pub fn resize_cache(&mut self, capacity: usize) {
        match self {
            Self::Unigram(model) => model.resize_cache(capacity),
            Self::BPE(model) => model.resize_cache(capacity),
            _ => (),
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
pub enum TrainerWrapper {
    BpeTrainer(BpeTrainer),
    WordPieceTrainer(WordPieceTrainer),
    WordLevelTrainer(WordLevelTrainer),
    UnigramTrainer(UnigramTrainer),
}
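
// Each `TrainerWrapper` variant can only train its matching `ModelWrapper`
// variant; `train` returns an error on a mismatch (see the first test below).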
impl Trainer for TrainerWrapper {
    type Model = ModelWrapper;

    fn should_show_progress(&self) -> bool {
        match self {
            Self::BpeTrainer(bpe) => bpe.should_show_progress(),
            Self::WordPieceTrainer(wpt) => wpt.should_show_progress(),
            Self::WordLevelTrainer(wlt) => wlt.should_show_progress(),
            Self::UnigramTrainer(ut) => ut.should_show_progress(),
        }
    }

    fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
        match self {
            Self::BpeTrainer(t) => match model {
                ModelWrapper::BPE(bpe) => t.train(bpe),
                _ => Err("BpeTrainer can only train a BPE".into()),
            },
            Self::WordPieceTrainer(t) => match model {
                ModelWrapper::WordPiece(wp) => t.train(wp),
                _ => Err("WordPieceTrainer can only train a WordPiece".into()),
            },
            Self::WordLevelTrainer(t) => match model {
                ModelWrapper::WordLevel(wl) => t.train(wl),
                _ => Err("WordLevelTrainer can only train a WordLevel".into()),
            },
            Self::UnigramTrainer(t) => match model {
                ModelWrapper::Unigram(u) => t.train(u),
                _ => Err("UnigramTrainer can only train a Unigram".into()),
            },
        }
    }

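    // `process` maps each raw sequence to the list of "words" the trainer should
    // count (typically the tokenizer's normalization + pre-tokenization pipeline).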
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        match self {
            Self::BpeTrainer(bpe) => bpe.feed(iterator, process),
            Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
            Self::WordLevelTrainer(wlt) => wlt.feed(iterator, process),
            Self::UnigramTrainer(ut) => ut.feed(iterator, process),
        }
    }
}

impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer);
impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer);
impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer);
impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::bpe::{BpeBuilder, Vocab};

    #[test]
    fn trainer_wrapper_train_model_wrapper() {
        let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let mut model = ModelWrapper::Unigram(Unigram::default());

        let result = trainer.train(&mut model);
        assert!(result.is_err());
    }
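
    #[test]
    fn model_wrapper_from_conversions() {
        // Illustrative sketch (not part of the original suite): the
        // `impl_enum_from!` conversions above lift concrete models and
        // trainers into their wrappers with `.into()`.
        let wrapper: ModelWrapper = Unigram::default().into();
        assert!(matches!(wrapper, ModelWrapper::Unigram(_)));

        let trainer: TrainerWrapper = BpeTrainer::default().into();
        assert!(matches!(trainer, TrainerWrapper::BpeTrainer(_)));
    }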

    #[test]
    fn incomplete_ordered_vocab() {
        let vocab_r: AHashMap<u32, String> =
            AHashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
    }

    #[test]
    fn serialization() {
        let vocab: Vocab = [
            ("<unk>".into(), 0),
            ("a".into(), 1),
            ("b".into(), 2),
            ("ab".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        let model = ModelWrapper::BPE(bpe);
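        // Legacy payloads store each merge as a single "a b" string; the current
        // format serializes merges as ["a", "b"] pairs, as the assertions below show.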
        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, legacy);

        let data = serde_json::to_string(&model).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(model, reconstructed);

        // Legacy check: the "type" field is not required.
        let legacy = r#"{"dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let reconstructed = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, reconstructed);

        let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
        let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
            serde_json::from_str(invalid);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
            _ => panic!("Expected an error here"),
        }
    }
}