pub mod bpe;
pub mod unigram;
pub mod wordlevel;
pub mod wordpiece;

use ahash::AHashMap;
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
use crate::models::wordlevel::{WordLevel, WordLevelTrainer};
use crate::models::wordpiece::{WordPiece, WordPieceTrainer};
use crate::{AddedToken, Model, Result, Token, Trainer};

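/// Wraps a reverse vocabulary (id -> token) so it serializes as a
/// `token -> id` map ordered by id, keeping saved vocab files deterministic.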
struct OrderedVocabIter<'a> {
    vocab_r: &'a AHashMap<u32, String>,
}

impl<'a> OrderedVocabIter<'a> {
    fn new(vocab_r: &'a AHashMap<u32, String>) -> Self {
        Self { vocab_r }
    }
}

impl Serialize for OrderedVocabIter<'_> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Walk ids from 0 to the maximum id; any id with no token is recorded
        // as a hole and reported after serialization.
        let mut holes = vec![];
        let result = if let Some(max) = self.vocab_r.keys().max() {
            let iter = (0..*max + 1).filter_map(|i| {
                if let Some(token) = self.vocab_r.get(&i) {
                    Some((token, i))
                } else {
                    holes.push(i);
                    None
                }
            });
            serializer.collect_map(iter)
        } else {
            serializer.collect_map(std::iter::empty::<(&str, u32)>())
        };

        if !holes.is_empty() {
            warn!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
            println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
        }
        result
    }
}

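/// Wrapper around the concrete model types, dispatching the `Model` trait to
/// whichever variant is held.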
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
    BPE(BPE),
    // Kept in the same order as `ModelUntagged` in the Deserialize impl below;
    // see the comment there about why WordPiece must come before WordLevel.
    WordPiece(WordPiece),
    WordLevel(WordLevel),
    Unigram(Unigram),
}

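// Deserialization first tries the modern, explicitly tagged form (a "type"
// field naming the model), then falls back to the legacy untagged form used
// by older serialized tokenizers.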
impl<'de> Deserialize<'de> for ModelWrapper {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        pub struct Tagged {
            #[serde(rename = "type")]
            variant: EnumType,
            #[serde(flatten)]
            rest: serde_json::Value,
        }
        #[derive(Deserialize)]
        pub enum EnumType {
            BPE,
            WordPiece,
            WordLevel,
            Unigram,
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelHelper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }

        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum ModelUntagged {
            BPE(BPE),
            // WordPiece must be tried before WordLevel: WordLevel's fields are a
            // subset of WordPiece's, so a legacy WordPiece payload would otherwise
            // deserialize as a WordLevel.
            WordPiece(WordPiece),
            WordLevel(WordLevel),
            Unigram(Unigram),
        }

        let helper = ModelHelper::deserialize(deserializer)?;
        Ok(match helper {
            ModelHelper::Tagged(model) => match model.variant {
                EnumType::BPE => ModelWrapper::BPE(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordPiece => ModelWrapper::WordPiece(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::WordLevel => ModelWrapper::WordLevel(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
                EnumType::Unigram => ModelWrapper::Unigram(
                    serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
                ),
            },
            ModelHelper::Legacy(value) => {
                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                match untagged {
                    ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
                    ModelUntagged::WordPiece(wp) => ModelWrapper::WordPiece(wp),
                    ModelUntagged::WordLevel(wl) => ModelWrapper::WordLevel(wl),
                    ModelUntagged::Unigram(u) => ModelWrapper::Unigram(u),
                }
            }
        })
    }
}

impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
impl_enum_from!(Unigram, ModelWrapper, Unigram);

impl Model for ModelWrapper {
    type Trainer = TrainerWrapper;

    fn tokenize(&self, tokens: &str) -> Result<Vec<Token>> {
        match self {
            Self::WordLevel(t) => t.tokenize(tokens),
            Self::WordPiece(t) => t.tokenize(tokens),
            Self::BPE(t) => t.tokenize(tokens),
            Self::Unigram(t) => t.tokenize(tokens),
        }
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        match self {
            Self::WordLevel(t) => t.token_to_id(token),
            Self::WordPiece(t) => t.token_to_id(token),
            Self::BPE(t) => t.token_to_id(token),
            Self::Unigram(t) => t.token_to_id(token),
        }
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        match self {
            Self::WordLevel(t) => t.id_to_token(id),
            Self::WordPiece(t) => t.id_to_token(id),
            Self::BPE(t) => t.id_to_token(id),
            Self::Unigram(t) => t.id_to_token(id),
        }
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        match self {
            Self::WordLevel(t) => t.get_vocab(),
            Self::WordPiece(t) => t.get_vocab(),
            Self::BPE(t) => t.get_vocab(),
            Self::Unigram(t) => t.get_vocab(),
        }
    }

    fn get_vocab_size(&self) -> usize {
        match self {
            Self::WordLevel(t) => t.get_vocab_size(),
            Self::WordPiece(t) => t.get_vocab_size(),
            Self::BPE(t) => t.get_vocab_size(),
            Self::Unigram(t) => t.get_vocab_size(),
        }
    }

    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        match self {
            Self::WordLevel(t) => t.save(folder, name),
            Self::WordPiece(t) => t.save(folder, name),
            Self::BPE(t) => t.save(folder, name),
            Self::Unigram(t) => t.save(folder, name),
        }
    }

    fn get_trainer(&self) -> Self::Trainer {
        match self {
            Self::WordLevel(t) => t.get_trainer().into(),
            Self::WordPiece(t) => t.get_trainer().into(),
            Self::BPE(t) => t.get_trainer().into(),
            Self::Unigram(t) => t.get_trainer().into(),
        }
    }
}

impl ModelWrapper {
    /// Clears the tokenization cache of the models that keep one (BPE and
    /// Unigram); a no-op for the other variants.
    pub fn clear_cache(&mut self) {
        match self {
            Self::Unigram(model) => model.clear_cache(),
            Self::BPE(model) => model.clear_cache(),
            _ => (),
        }
    }

    /// Resizes the tokenization cache of the models that keep one (BPE and
    /// Unigram); a no-op for the other variants.
    pub fn resize_cache(&mut self, capacity: usize) {
        match self {
            Self::Unigram(model) => model.resize_cache(capacity),
            Self::BPE(model) => model.resize_cache(capacity),
            _ => (),
        }
    }
}

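/// Wrapper around the concrete trainer types, mirroring `ModelWrapper`.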
#[derive(Clone, Serialize, Deserialize)]
pub enum TrainerWrapper {
    BpeTrainer(BpeTrainer),
    WordPieceTrainer(WordPieceTrainer),
    WordLevelTrainer(WordLevelTrainer),
    UnigramTrainer(UnigramTrainer),
}

impl Trainer for TrainerWrapper {
    type Model = ModelWrapper;

    fn should_show_progress(&self) -> bool {
        match self {
            Self::BpeTrainer(t) => t.should_show_progress(),
            Self::WordPieceTrainer(t) => t.should_show_progress(),
            Self::WordLevelTrainer(t) => t.should_show_progress(),
            Self::UnigramTrainer(t) => t.should_show_progress(),
        }
    }

    fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
        // Each trainer only knows how to train its matching model type; any
        // other pairing is reported as an error.
        match self {
            Self::BpeTrainer(t) => match model {
                ModelWrapper::BPE(bpe) => t.train(bpe),
                _ => Err("BpeTrainer can only train a BPE".into()),
            },
            Self::WordPieceTrainer(t) => match model {
                ModelWrapper::WordPiece(wp) => t.train(wp),
                _ => Err("WordPieceTrainer can only train a WordPiece".into()),
            },
            Self::WordLevelTrainer(t) => match model {
                ModelWrapper::WordLevel(wl) => t.train(wl),
                _ => Err("WordLevelTrainer can only train a WordLevel".into()),
            },
            Self::UnigramTrainer(t) => match model {
                ModelWrapper::Unigram(u) => t.train(u),
                _ => Err("UnigramTrainer can only train a Unigram".into()),
            },
        }
    }

    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        match self {
            Self::BpeTrainer(t) => t.feed(iterator, process),
            Self::WordPieceTrainer(t) => t.feed(iterator, process),
            Self::WordLevelTrainer(t) => t.feed(iterator, process),
            Self::UnigramTrainer(t) => t.feed(iterator, process),
        }
    }
}

impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer);
impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer);
impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer);
impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::bpe::{BpeBuilder, Vocab};

    #[test]
    fn trainer_wrapper_train_model_wrapper() {
        let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let mut model = ModelWrapper::Unigram(Unigram::default());

        let result = trainer.train(&mut model);
        assert!(result.is_err());
    }
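
    // A companion sketch (not part of the original tests) of the matching
    // pairing: a BpeTrainer fed a tiny corpus should train a BPE model without
    // error. The one-line corpus and the whitespace-splitting closure are
    // assumptions standing in for a real dataset and pre-tokenizer.
    #[test]
    fn trainer_wrapper_train_matching_model() {
        let mut trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let mut model = ModelWrapper::BPE(BPE::default());

        trainer
            .feed(vec!["hello world hello"].into_iter(), |line| {
                Ok(line.split_whitespace().map(str::to_owned).collect())
            })
            .unwrap();
        assert!(trainer.train(&mut model).is_ok());
    }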

    #[test]
    fn incomplete_ordered_vocab() {
        let vocab_r: AHashMap<u32, String> =
            AHashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
    }
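
    // A counterpart to the test above (not in the original): with no holes in
    // the id range, the serialized map simply lists every token in id order.
    #[test]
    fn complete_ordered_vocab() {
        let vocab_r: AHashMap<u32, String> =
            AHashMap::from([(0, "Hi".to_string()), (1, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":1}");
    }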

    #[test]
    fn serialization() {
        let vocab: Vocab = [
            ("<unk>".into(), 0),
            ("a".into(), 1),
            ("b".into(), 2),
            ("ab".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        let model = ModelWrapper::BPE(bpe);
        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, legacy);

        let data = serde_json::to_string(&model).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(model, reconstructed);

        // The legacy form without the "type" tag still deserializes through the
        // untagged fallback.
        let legacy = r#"{"dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let reconstructed = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, reconstructed);

        let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
        let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
            serde_json::from_str(invalid);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
            _ => panic!("Expected an error here"),
        }
        }
    }
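
    // A hedged sketch (not in the original tests) of the tagged path for a
    // non-BPE model. The JSON field names below mirror what `WordLevel`'s
    // serializer is expected to emit and are an assumption, not a spec.
    #[test]
    fn tagged_wordlevel_deserialization() {
        let json = r#"{"type":"WordLevel","vocab":{"<unk>":0,"hello":1},"unk_token":"<unk>"}"#;
        let model: ModelWrapper = serde_json::from_str(json).unwrap();

        assert!(matches!(model, ModelWrapper::WordLevel(_)));
        assert_eq!(model.token_to_id("hello"), Some(1));

        // Round-tripping through serde should preserve the model.
        let data = serde_json::to_string(&model).unwrap();
        let reconstructed: ModelWrapper = serde_json::from_str(&data).unwrap();
        assert_eq!(model, reconstructed);
    }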
}