1use std::marker::PhantomData;
2
3use serde::{
4 self,
5 de::{Error, MapAccess, Visitor},
6 ser::SerializeStruct,
7 Deserialize, Deserializer, Serialize, Serializer,
8};
9
10use super::{added_vocabulary::AddedTokenWithId, TokenizerImpl};
11use crate::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer, TokenizerBuilder};
12
13static SERIALIZATION_VERSION: &str = "1.0";
14
15impl<M, N, PT, PP, D> Serialize for TokenizerImpl<M, N, PT, PP, D>
16where
17 M: Serialize,
18 N: Serialize,
19 PT: Serialize,
20 PP: Serialize,
21 D: Serialize,
22{
23 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24 where
25 S: Serializer,
26 {
27 let mut tokenizer = serializer.serialize_struct("Tokenizer", 9)?;
28
29 tokenizer.serialize_field("version", SERIALIZATION_VERSION)?;
31
32 tokenizer.serialize_field("truncation", &self.truncation)?;
34 tokenizer.serialize_field("padding", &self.padding)?;
35
36 tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;
38
39 tokenizer.serialize_field("normalizer", &self.normalizer)?;
41 tokenizer.serialize_field("pre_tokenizer", &self.pre_tokenizer)?;
42 tokenizer.serialize_field("post_processor", &self.post_processor)?;
43 tokenizer.serialize_field("decoder", &self.decoder)?;
44 tokenizer.serialize_field("model", &self.model)?;
45
46 tokenizer.end()
47 }
48}
49
50impl<'de, M, N, PT, PP, D> Deserialize<'de> for TokenizerImpl<M, N, PT, PP, D>
51where
52 M: Deserialize<'de> + Model,
53 N: Deserialize<'de> + Normalizer,
54 PT: Deserialize<'de> + PreTokenizer,
55 PP: Deserialize<'de> + PostProcessor,
56 D: Deserialize<'de> + Decoder,
57{
58 fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
59 where
60 De: Deserializer<'de>,
61 {
62 deserializer.deserialize_struct(
63 "Tokenizer",
64 &[
65 "version",
66 "truncation",
67 "padding",
68 "added_tokens",
69 "normalizer",
70 "pre_tokenizer",
71 "post_processor",
72 "decoder",
73 "model",
74 ],
75 TokenizerVisitor(
76 PhantomData,
77 PhantomData,
78 PhantomData,
79 PhantomData,
80 PhantomData,
81 ),
82 )
83 }
84}
85
/// Serde visitor used to deserialize a `TokenizerImpl`; the `PhantomData`
/// fields only carry the five generic component types (model, normalizer,
/// pre-tokenizer, post-processor, decoder) through to the `Visitor::Value`.
struct TokenizerVisitor<M, N, PT, PP, D>(
    PhantomData<M>,
    PhantomData<N>,
    PhantomData<PT>,
    PhantomData<PP>,
    PhantomData<D>,
);
93
94impl<'de, M, N, PT, PP, D> Visitor<'de> for TokenizerVisitor<M, N, PT, PP, D>
95where
96 M: Deserialize<'de> + Model,
97 N: Deserialize<'de> + Normalizer,
98 PT: Deserialize<'de> + PreTokenizer,
99 PP: Deserialize<'de> + PostProcessor,
100 D: Deserialize<'de> + Decoder,
101{
102 type Value = TokenizerImpl<M, N, PT, PP, D>;
103
104 fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
105 write!(fmt, "struct Tokenizer")
106 }
107
108 fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
109 where
110 V: MapAccess<'de>,
111 {
112 let mut builder = TokenizerBuilder::new();
113 let mut tokens: Vec<AddedTokenWithId> = vec![];
114 while let Some(key) = map.next_key::<String>()? {
115 match key.as_ref() {
116 "version" => {
117 let v: String = map.next_value()?;
118 if &v != "1.0" {
119 return Err(Error::custom(format!("Unknown tokenizer version '{v}'")));
120 }
121 }
122 "truncation" => {
123 builder = builder.with_truncation(map.next_value()?);
124 }
125 "padding" => {
126 builder = builder.with_padding(map.next_value()?);
127 }
128 "added_tokens" => {
129 tokens = map.next_value()?;
130 }
131 "normalizer" => {
132 builder = builder.with_normalizer(map.next_value()?);
133 }
134 "pre_tokenizer" => {
135 builder = builder.with_pre_tokenizer(map.next_value()?);
136 }
137 "model" => {
138 builder = builder.with_model(map.next_value()?);
139 }
140 "decoder" => {
141 builder = builder.with_decoder(map.next_value()?);
142 }
143 "post_processor" => {
144 builder = builder.with_post_processor(map.next_value()?);
145 }
146 _ => {}
147 };
148 }
149 let mut tokenizer = builder
150 .build()
151 .map_err(|e| V::Error::custom(e.to_string()))?;
152
153 for token in &tokens {
156 let received_id = tokenizer.token_to_id(&token.token.content);
158 if let Some(rid) = received_id {
159 if rid != token.id {
160 warn!(
161 "Warning: Token '{}' was expected to have ID '{}' but was given ID '{}'",
162 token.token.content,
163 token.id,
164 rid.to_string()
165 );
166 }
167 }
168 }
169 let added_tokens: Vec<_> = tokens.into_iter().map(|token| token.token).collect();
170 tokenizer.add_tokens(&added_tokens[..]);
171
172 Ok(tokenizer)
173 }
174}
175
#[cfg(test)]
mod tests {
    use crate::tokenizer::Tokenizer;
    use std::str::FromStr;

    /// Round-trip invariant: deserializing this JSON and re-serializing it
    /// with `serde_json::to_string_pretty` must reproduce the input
    /// byte-for-byte. The literal below therefore encodes both the field
    /// order of the Serialize impl and serde_json's pretty formatting
    /// (2-space indent) — do not reformat it.
    #[test]
    fn test_deserialization_serialization_invariant() {
        let tok_json = r#"{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {
      "id": 0,
      "content": "[SPECIAL_0]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 1,
      "content": "[SPECIAL_1]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": true,
      "special": false
    },
    {
      "id": 2,
      "content": "[SPECIAL_2]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "normalizer": null,
  "pre_tokenizer": null,
  "post_processor": null,
  "decoder": null,
  "model": {
    "type": "WordPiece",
    "unk_token": "[UNK]",
    "continuing_subword_prefix": "",
    "max_input_chars_per_word": 100,
    "vocab": {}
  }
}"#;
        let tokenizer = Tokenizer::from_str(tok_json).unwrap();

        // Serializing again must yield exactly the original JSON text.
        let tok_str = serde_json::to_string_pretty(&tokenizer).unwrap();
        assert_eq!(tok_str, tok_json);
    }

    /// Smoke test for loading a tokenizer from the Hub. Only compiled with
    /// the `http` feature; the download result is deliberately ignored —
    /// the point is to exercise logging during `from_pretrained`.
    // NOTE(review): `init()` installs a global subscriber and will panic if
    // another test already set one — confirm this test runs in isolation.
    #[cfg(feature = "http")]
    #[test]
    fn test_from_pretrained() {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::DEBUG)
            .with_target(false)
            .init();
        let _ = Tokenizer::from_pretrained("Qwen/Qwen2-7B-Instruct", None);
        warn!("This should be the first warning");
    }
}