1pub mod bpe;
4pub mod unigram;
5pub mod wordlevel;
6pub mod wordpiece;
7
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13use crate::models::bpe::{BpeTrainer, BPE};
14use crate::models::unigram::{Unigram, UnigramTrainer};
15use crate::models::wordlevel::{WordLevel, WordLevelTrainer};
16use crate::models::wordpiece::{WordPiece, WordPieceTrainer};
17use crate::{AddedToken, Model, Result, Token, Trainer};
18
/// Serialization helper that emits a vocabulary as a `token -> id` map,
/// ordered by increasing id.
///
/// Wraps a borrowed reverse vocabulary (`id -> token`); see the `Serialize`
/// impl for how missing ids ("holes") are handled.
struct OrderedVocabIter<'a> {
    // Reverse vocabulary: id -> token.
    vocab_r: &'a HashMap<u32, String>,
}

impl<'a> OrderedVocabIter<'a> {
    /// Creates an ordered-serialization view over the given reverse vocabulary.
    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
        Self { vocab_r }
    }
}
30
31impl<'a> Serialize for OrderedVocabIter<'a> {
32 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
33 where
34 S: Serializer,
35 {
36 let mut holes = vec![];
38 let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
39 let iter = (0..*max + 1).filter_map(|i| {
40 if let Some(token) = self.vocab_r.get(&i) {
41 Some((token, i))
42 } else {
43 holes.push(i);
44 None
45 }
46 });
47 serializer.collect_map(iter)
48 } else {
49 serializer.collect_map(std::iter::empty::<(&str, u32)>())
50 };
51
52 if !holes.is_empty() {
53 warn!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
54 println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
55 }
56 result
57 }
58}
59
/// Unifies every supported tokenization model behind a single enum so the
/// tokenizer pipeline can hold "some model" without generics.
///
/// Serialization is untagged here; the custom `Deserialize` impl below
/// additionally accepts a `"type"`-tagged form and a legacy untagged form.
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
    BPE(BPE),
    WordPiece(WordPiece),
    WordLevel(WordLevel),
    Unigram(Unigram),
}
70
71impl<'de> Deserialize<'de> for ModelWrapper {
72 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
73 where
74 D: Deserializer<'de>,
75 {
76 #[derive(Deserialize)]
77 pub struct Tagged {
78 #[serde(rename = "type")]
79 variant: EnumType,
80 #[serde(flatten)]
81 rest: serde_json::Value,
82 }
83 #[derive(Deserialize)]
84 pub enum EnumType {
85 BPE,
86 WordPiece,
87 WordLevel,
88 Unigram,
89 }
90
91 #[derive(Deserialize)]
92 #[serde(untagged)]
93 pub enum ModelHelper {
94 Tagged(Tagged),
95 Legacy(serde_json::Value),
96 }
97
98 #[derive(Deserialize)]
99 #[serde(untagged)]
100 pub enum ModelUntagged {
101 BPE(BPE),
102 WordPiece(WordPiece),
105 WordLevel(WordLevel),
106 Unigram(Unigram),
107 }
108
109 let helper = ModelHelper::deserialize(deserializer)?;
110 Ok(match helper {
111 ModelHelper::Tagged(model) => match model.variant {
112 EnumType::BPE => ModelWrapper::BPE(
113 serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
114 ),
115 EnumType::WordPiece => ModelWrapper::WordPiece(
116 serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
117 ),
118 EnumType::WordLevel => ModelWrapper::WordLevel(
119 serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
120 ),
121 EnumType::Unigram => ModelWrapper::Unigram(
122 serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
123 ),
124 },
125 ModelHelper::Legacy(value) => {
126 let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
127 match untagged {
128 ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
129 ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
130 ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
131 ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
132 }
133 }
134 })
135 }
136}
137
// Generate `From<Model>` conversions into `ModelWrapper` for each concrete
// model type (project-local `impl_enum_from!` macro).
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
impl_enum_from!(Unigram, ModelWrapper, Unigram);
142
143impl Model for ModelWrapper {
144 type Trainer = TrainerWrapper;
145
146 fn tokenize(&self, tokens: &str) -> Result<Vec<Token>> {
147 match self {
148 Self::WordLevel(t) => t.tokenize(tokens),
149 Self::WordPiece(t) => t.tokenize(tokens),
150 Self::BPE(t) => t.tokenize(tokens),
151 Self::Unigram(t) => t.tokenize(tokens),
152 }
153 }
154
155 fn token_to_id(&self, token: &str) -> Option<u32> {
156 match self {
157 Self::WordLevel(t) => t.token_to_id(token),
158 Self::WordPiece(t) => t.token_to_id(token),
159 Self::BPE(t) => t.token_to_id(token),
160 Self::Unigram(t) => t.token_to_id(token),
161 }
162 }
163
164 fn id_to_token(&self, id: u32) -> Option<String> {
165 match self {
166 Self::WordLevel(t) => t.id_to_token(id),
167 Self::WordPiece(t) => t.id_to_token(id),
168 Self::BPE(t) => t.id_to_token(id),
169 Self::Unigram(t) => t.id_to_token(id),
170 }
171 }
172
173 fn get_vocab(&self) -> HashMap<String, u32> {
174 match self {
175 Self::WordLevel(t) => t.get_vocab(),
176 Self::WordPiece(t) => t.get_vocab(),
177 Self::BPE(t) => t.get_vocab(),
178 Self::Unigram(t) => t.get_vocab(),
179 }
180 }
181
182 fn get_vocab_size(&self) -> usize {
183 match self {
184 Self::WordLevel(t) => t.get_vocab_size(),
185 Self::WordPiece(t) => t.get_vocab_size(),
186 Self::BPE(t) => t.get_vocab_size(),
187 Self::Unigram(t) => t.get_vocab_size(),
188 }
189 }
190
191 fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
192 match self {
193 Self::WordLevel(t) => t.save(folder, name),
194 Self::WordPiece(t) => t.save(folder, name),
195 Self::BPE(t) => t.save(folder, name),
196 Self::Unigram(t) => t.save(folder, name),
197 }
198 }
199
200 fn get_trainer(&self) -> Self::Trainer {
201 match self {
202 Self::WordLevel(t) => t.get_trainer().into(),
203 Self::WordPiece(t) => t.get_trainer().into(),
204 Self::BPE(t) => t.get_trainer().into(),
205 Self::Unigram(t) => t.get_trainer().into(),
206 }
207 }
208}
209
210impl ModelWrapper {
211 pub fn clear_cache(&mut self) {
212 match self {
213 Self::Unigram(model) => model.clear_cache(),
214 Self::BPE(model) => model.clear_cache(),
215 _ => (),
216 }
217 }
218 pub fn resize_cache(&mut self, capacity: usize) {
219 match self {
220 Self::Unigram(model) => model.resize_cache(capacity),
221 Self::BPE(model) => model.resize_cache(capacity),
222 _ => (),
223 }
224 }
225}
226
/// Unifies every supported model trainer behind a single enum, mirroring
/// `ModelWrapper`; each variant can only train its matching model.
#[derive(Clone, Serialize, Deserialize)]
pub enum TrainerWrapper {
    BpeTrainer(BpeTrainer),
    WordPieceTrainer(WordPieceTrainer),
    WordLevelTrainer(WordLevelTrainer),
    UnigramTrainer(UnigramTrainer),
}
234
235impl Trainer for TrainerWrapper {
236 type Model = ModelWrapper;
237
238 fn should_show_progress(&self) -> bool {
239 match self {
240 Self::BpeTrainer(bpe) => bpe.should_show_progress(),
241 Self::WordPieceTrainer(wpt) => wpt.should_show_progress(),
242 Self::WordLevelTrainer(wpt) => wpt.should_show_progress(),
243 Self::UnigramTrainer(wpt) => wpt.should_show_progress(),
244 }
245 }
246
247 fn train(&self, model: &mut ModelWrapper) -> Result<Vec<AddedToken>> {
248 match self {
249 Self::BpeTrainer(t) => match model {
250 ModelWrapper::BPE(bpe) => t.train(bpe),
251 _ => Err("BpeTrainer can only train a BPE".into()),
252 },
253 Self::WordPieceTrainer(t) => match model {
254 ModelWrapper::WordPiece(wp) => t.train(wp),
255 _ => Err("WordPieceTrainer can only train a WordPiece".into()),
256 },
257 Self::WordLevelTrainer(t) => match model {
258 ModelWrapper::WordLevel(wl) => t.train(wl),
259 _ => Err("WordLevelTrainer can only train a WordLevel".into()),
260 },
261 Self::UnigramTrainer(t) => match model {
262 ModelWrapper::Unigram(u) => t.train(u),
263 _ => Err("UnigramTrainer can only train a Unigram".into()),
264 },
265 }
266 }
267
268 fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
269 where
270 I: Iterator<Item = S> + Send,
271 S: AsRef<str> + Send,
272 F: Fn(&str) -> Result<Vec<String>> + Sync,
273 {
274 match self {
275 Self::BpeTrainer(bpe) => bpe.feed(iterator, process),
276 Self::WordPieceTrainer(wpt) => wpt.feed(iterator, process),
277 Self::WordLevelTrainer(wpt) => wpt.feed(iterator, process),
278 Self::UnigramTrainer(wpt) => wpt.feed(iterator, process),
279 }
280 }
281}
282
// Generate `From<Trainer>` conversions into `TrainerWrapper` for each
// concrete trainer type (project-local `impl_enum_from!` macro).
impl_enum_from!(BpeTrainer, TrainerWrapper, BpeTrainer);
impl_enum_from!(WordPieceTrainer, TrainerWrapper, WordPieceTrainer);
impl_enum_from!(UnigramTrainer, TrainerWrapper, UnigramTrainer);
impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::bpe::{BpeBuilder, Vocab};

    // A trainer/model variant mismatch must be reported as an error, not a
    // panic or a silent no-op.
    #[test]
    fn trainer_wrapper_train_model_wrapper() {
        let trainer = TrainerWrapper::BpeTrainer(BpeTrainer::default());
        let mut model = ModelWrapper::Unigram(Unigram::default());

        let result = trainer.train(&mut model);
        assert!(result.is_err());
    }

    // A reverse vocab with a hole (id 1 missing) still serializes the tokens
    // it does have, keyed by their original ids.
    #[test]
    fn incomplete_ordered_vocab() {
        let vocab_r: HashMap<u32, String> =
            HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);

        let ordered = OrderedVocabIter::new(&vocab_r);

        let serialized = serde_json::to_string(&ordered).unwrap();
        assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
    }

    // Round-trips a BPE model through both supported JSON layouts:
    // the current tagged form and the legacy untagged form.
    #[test]
    fn serialization() {
        let vocab: Vocab = [
            ("<unk>".into(), 0),
            ("a".into(), 1),
            ("b".into(), 2),
            ("ab".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        // Tagged layout with old-style space-separated merge strings.
        let model = ModelWrapper::BPE(bpe);
        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, legacy);

        // Current serialization uses pair-array merges; must round-trip.
        let data = serde_json::to_string(&model).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(model, reconstructed);

        // Legacy layout without a "type" tag must still deserialize via the
        // untagged fallback.
        let legacy = r#"{"dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let reconstructed = serde_json::from_str(legacy).unwrap();
        assert_eq!(model, reconstructed);

        // A malformed merge entry ("a b c") must surface a descriptive error.
        let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
        let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
            serde_json::from_str(invalid);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
            _ => panic!("Expected an error here"),
        }
    }
}
357}