// tokenizers/models/mod.rs

//! Popular tokenizer models.

pub mod bpe;
pub mod unigram;
pub mod wordlevel;
pub mod wordpiece;

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
use crate::models::wordlevel::{WordLevel, WordLevelTrainer};
use crate::models::wordpiece::{WordPiece, WordPieceTrainer};
use crate::{AddedToken, Model, Result, Token, Trainer};

/// Wraps a vocab mapping (ID -> token) in a struct that will be serialized in order
/// of token ID, smallest to largest.
struct OrderedVocabIter<'a> {
    vocab_r: &'a HashMap<u32, String>,
}

impl<'a> OrderedVocabIter<'a> {
    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
        Self { vocab_r }
    }
}

impl<'a> Serialize for OrderedVocabIter<'a> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // There could be holes, so iterating up to `max + 1` is more correct than
        // `vocab_r.len()`: e.g. {0: "a", 2: "b"} has length 2 but needs IDs 0..=2.
        let mut holes = vec![];
        let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
            let iter = (0..*max + 1).filter_map(|i| {
                if let Some(token) = self.vocab_r.get(&i) {
                    Some((token, i))
                } else {
                    holes.push(i);
                    None
                }
            });
            serializer.collect_map(iter)
        } else {
            serializer.collect_map(std::iter::empty::<(&str, u32)>())
        };

        if !holes.is_empty() {
            warn!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
            println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
        }
        result
    }
}

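// Note (added): serialization stays `untagged` because each concrete model already
// emits its own `"type"` field, as exercised by `tests::serialization` below.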
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
    BPE(BPE),
    // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
    // with versions that did not include the "type" field), since WordLevel is a subset of WordPiece.
    WordPiece(WordPiece),
    WordLevel(WordLevel),
    Unigram(Unigram),
}

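// Added explanatory note: deserialization first looks for a `"type"` tag (the
// modern format) and dispatches on it; if the tag is absent, it falls back to
// trying each variant untagged, in `ModelUntagged` order, for legacy files.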
impl<'de> Deserialize<'de> for ModelWrapper {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        pub struct Tagged {
            #[serde(rename = "type")]
            variant: EnumType,
            #[serde(flatten)]
            rest: serde_json::Value,
        }
        #[derive(Deserialize)]
        pub enum EnumType {
            BPE,
            WordPiece,
            WordLevel,
            Unigram,
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelHelper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelUntagged {
            BPE(BPE),
            // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
            // with versions that did not include the "type" field), since WordLevel is a subset of WordPiece.
            WordPiece(WordPiece),
            WordLevel(WordLevel),
            Unigram(Unigram),
        }

        let helper = ModelHelper::deserialize(deserializer)?;
        Ok(match helper {
            ModelHelper::Tagged(model) => match model.variant {
                EnumType::BPE => ModelWrapper::BPE(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordPiece => ModelWrapper::WordPiece(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordLevel => ModelWrapper::WordLevel(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::Unigram => ModelWrapper::Unigram(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
            },
            ModelHelper::Legacy(value) => {
                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                match untagged {
                    ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
                    ModelUntagged::WordPiece(wp) => ModelWrapper::WordPiece(wp),
                    ModelUntagged::WordLevel(wl) => ModelWrapper::WordLevel(wl),
                    ModelUntagged::Unigram(unigram) => ModelWrapper::Unigram(unigram),
                }
            }
        })
    }
}

impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
impl_enum_from!(Unigram, ModelWrapper, Unigram);

impl Model for ModelWrapper {
    type Trainer = TrainerWrapper;

    fn tokenize(&self, tokens: &str) -> Result<Vec<Token>> {
        match self {
            Self::WordLevel(t) => t.tokenize(tokens),
            Self::WordPiece(t) => t.tokenize(tokens),
            Self::BPE(t) => t.tokenize(tokens),
            Self::Unigram(t) => t.tokenize(tokens),
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        match self {
            Self::WordLevel(t) => t.token_to_id(token),
            Self::WordPiece(t) => t.token_to_id(token),
            Self::BPE(t) => t.token_to_id(token),
            Self::Unigram(t) => t.token_to_id(token),
        }
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        match self {
            Self::WordLevel(t) => t.id_to_token(id),
            Self::WordPiece(t) => t.id_to_token(id),
            Self::BPE(t) => t.id_to_token(id),
            Self::Unigram(t) => t.id_to_token(id),
        }
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        match self {
            Self::WordLevel(t) => t.get_vocab(),
            Self::WordPiece(t) => t.get_vocab(),
            Self::BPE(t) => t.get_vocab(),
            Self::Unigram(t) => t.get_vocab(),
        }
    }

    fn get_vocab_size(&self) -> usize {
        match self {
            Self::WordLevel(t) => t.get_vocab_size(),
            Self::WordPiece(t) => t.get_vocab_size(),
            Self::BPE(t) => t.get_vocab_size(),
            Self::Unigram(t) => t.get_vocab_size(),
        }
    }

    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        match self {
            Self::WordLevel(t) => t.save(folder, name),
            Self::WordPiece(t) => t.save(folder, name),
            Self::BPE(t) => t.save(folder, name),
            Self::Unigram(t) => t.save(folder, name),
        }
    }

    fn get_trainer(&self) -> Self::Trainer {
        match self {
            Self::WordLevel(t) => t.get_trainer().into(),
            Self::WordPiece(t) => t.get_trainer().into(),
            Self::BPE(t) => t.get_trainer().into(),
            Self::Unigram(t) => t.get_trainer().into(),
        }
    }
}

impl ModelWrapper {
    /// Clear the internal cache, for the model variants that keep one (BPE and Unigram).
    pub fn clear_cache(&mut self) {
        match self {
            Self::Unigram(model) => model.clear_cache(),
            Self::BPE(model) => model.clear_cache(),
            _ => (),
        }
    }

    /// Resize the internal cache, for the model variants that keep one (BPE and Unigram).
    pub fn resize_cache(&mut self, capacity: usize) {
        match self {
            Self::Unigram(model) => model.resize_cache(capacity),
            Self::BPE(model) => model.resize_cache(capacity),
            _ => (),
        }
    }
}

#[derive(Clone, Serialize, Deserialize)]
pub enum TrainerWrapper {
    BpeTrainer(BpeTrainer),
    WordPieceTrainer(WordPieceTrainer),
    WordLevelTrainer(WordLevelTrainer),
    UnigramTrainer(UnigramTrainer),
}

impl Trainer for TrainerWrapper {
    type Model = ModelWrapper;

    fn should_show_progress(&self) -> bool {
        match self {
            Self::BpeTrainer(bpe) => bpe.should_show_progress(),
            Self::WordPieceTrainer(wpt) => wpt.should_show_progress(),
            Self::WordLevelTrainer(wlt) => wlt.should_show_progress(),
            Self::UnigramTrainer(ut) => ut.should_show_progress(),
        }
    }

    fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
        match self {
            Self::BpeTrainer(t) => match model {
                ModelWrapper::BPE(bpe) => t.train(bpe),
                _ => Err("BpeTrainer can only train a BPE".into()),
            },
            Self::WordPieceTrainer(t) => match model {
                ModelWrapper::WordPiece(wp) => t.train(wp),
                _ => Err("WordPieceTrainer can only train a WordPiece".into()),
            },
            Self::WordLevelTrainer(t) => match model {
                ModelWrapper::WordLevel(wl) => t.train(wl),
                _ => Err("WordLevelTrainer can only train a WordLevel".into()),
            },
            Self::UnigramTrainer(t) => match model {
                ModelWrapper::Unigram(u) => t.train(u),
                _ => Err("UnigramTrainer can only train a Unigram".into()),
            },
        }
    }

    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        match self {
            Self::BpeTrainer(bpe) => bpe.feed(iterator, process),
            Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
            Self::WordLevelTrainer(wlt) => wlt.feed(iterator, process),
            Self::UnigramTrainer(ut) => ut.feed(iterator, process),
        }
    }
}

impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer);
impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer);
impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer);
impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::bpe::{BpeBuilder, Vocab};

    #[test]
    fn trainer_wrapper_train_model_wrapper() {
        let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let mut model = ModelWrapper::Unigram(Unigram::default());

        let result = trainer.train(&mut model);
        assert!(result.is_err());
    }

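    // Added illustration (not in the upstream tests): `TrainerWrapper` derives
    // plain externally tagged serde, so a trainer round-trips through JSON as
    // e.g. {"BpeTrainer": {...}} and comes back as the same variant.
    #[test]
    fn trainer_wrapper_roundtrip() {
        let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let data = serde_json::to_string(&trainer).unwrap();
        let reloaded: TrainerWrapper = serde_json::from_str(&data).unwrap();
        assert!(matches!(reloaded, TrainerWrapper::BpeTrainer(_)));
    }

    // Added sketch of the feed/train flow through the wrappers, assuming the
    // default trainer settings and `BPE::default()`: `feed` consumes the
    // pre-processed words, then `train` fills the matching model variant.
    #[test]
    fn trainer_wrapper_feed_and_train() {
        let mut trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        trainer
            .feed(vec!["ab ab", "ab b"].into_iter(), |s| {
                Ok(s.split_whitespace().map(|w| w.to_owned()).collect())
            })
            .unwrap();

        let mut model = ModelWrapper::BPE(BPE::default());
        trainer.train(&mut model).unwrap();
        assert!(model.get_vocab_size() > 0);
    }
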
    #[test]
    fn incomplete_ordered_vocab() {
        let vocab_r: HashMap<u32, String> =
            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
    }

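    // Added illustration (not in the upstream tests): with no holes, the
    // entries are emitted strictly in order of token ID.
    #[test]
    fn complete_ordered_vocab() {
        let vocab_r: HashMap<u32, String> =
            HashMap::from([(0, "Hi".to_string()), (1, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":1}");
    }
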
    #[test]
    fn serialization() {
        let vocab: Vocab = [
            ("<unk>".into(), 0),
            ("a".into(), 1),
            ("b".into(), 2),
            ("ab".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        let model = ModelWrapper::BPE(bpe);
        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, legacy);

        let data = serde_json::to_string(&model).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(model, reconstructed);

        // Legacy check: the "type" field is not necessary.
        let legacy = r#"{"dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let reconstructed = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, reconstructed);

        let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
        let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
            serde_json::from_str(invalid);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
            _ => panic!("Expected an error here"),
        }
    }
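
    // Added usage sketch (not in the upstream tests): the wrapper simply
    // delegates the `Model` trait methods to the wrapped model, so tokenizing
    // through `ModelWrapper` behaves exactly like the inner BPE.
    #[test]
    fn model_wrapper_delegates_tokenize() {
        let vocab: Vocab = [
            ("<unk>".into(), 0),
            ("a".into(), 1),
            ("b".into(), 2),
            ("ab".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();
        let model = ModelWrapper::BPE(bpe);

        assert_eq!(model.token_to_id("ab"), Some(3));
        assert_eq!(model.id_to_token(3), Some("ab".to_string()));
        // "a" and "b" merge into the single known token "ab".
        let tokens = model.tokenize("ab").unwrap();
        assert_eq!(tokens.len(), 1);
        assert_eq!(tokens[0].id, 3);
    }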
}