// tokenizers/tokenizer/serialization.rs

use std::marker::PhantomData;

use log::warn;
use serde::{
    self,
    de::{Error, IgnoredAny, MapAccess, Visitor},
    ser::SerializeStruct,
    Deserialize, Deserializer, Serialize, Serializer,
};

use super::{added_vocabulary::AddedTokenWithId, TokenizerImpl};
use crate::{Decoder, Model, Normalizer, PostProcessor, PreTokenizer, TokenizerBuilder};

/// Format version written on save and checked on load.
static SERIALIZATION_VERSION: &str = "1.0";

impl<M, N, PT, PP, D> Serialize for TokenizerImpl<M, N, PT, PP, D>
where
    M: Serialize,
    N: Serialize,
    PT: Serialize,
    PP: Serialize,
    D: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut tokenizer = serializer.serialize_struct("Tokenizer", 9)?;

        // Start by adding the current version
        tokenizer.serialize_field("version", SERIALIZATION_VERSION)?;

        // Params
        tokenizer.serialize_field("truncation", &self.truncation)?;
        tokenizer.serialize_field("padding", &self.padding)?;

        // Added tokens
        tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;

        // Then add our parts
        tokenizer.serialize_field("normalizer", &self.normalizer)?;
        tokenizer.serialize_field("pre_tokenizer", &self.pre_tokenizer)?;
        tokenizer.serialize_field("post_processor", &self.post_processor)?;
        tokenizer.serialize_field("decoder", &self.decoder)?;
        tokenizer.serialize_field("model", &self.model)?;

        tokenizer.end()
    }
}

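// A minimal usage sketch (not part of the original file): with the impl
// above, a tokenizer serializes through serde_json like any other serde
// type, e.g. `serde_json::to_string_pretty(&tokenizer)`, which is exactly
// what the round-trip test at the bottom of this file exercises.
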
impl<'de, M, N, PT, PP, D> Deserialize<'de> for TokenizerImpl<M, N, PT, PP, D>
where
    M: Deserialize<'de> + Model,
    N: Deserialize<'de> + Normalizer,
    PT: Deserialize<'de> + PreTokenizer,
    PP: Deserialize<'de> + PostProcessor,
    D: Deserialize<'de> + Decoder,
{
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        De: Deserializer<'de>,
    {
        deserializer.deserialize_struct(
            "Tokenizer",
            &[
                "version",
                "truncation",
                "padding",
                "added_tokens",
                "normalizer",
                "pre_tokenizer",
                "post_processor",
                "decoder",
                "model",
            ],
            TokenizerVisitor(
                PhantomData,
                PhantomData,
                PhantomData,
                PhantomData,
                PhantomData,
            ),
        )
    }
}

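// The visitor itself holds no data: the `PhantomData` fields only thread the
// five generic parameters (model, normalizer, pre-tokenizer, post-processor,
// decoder) through to the `Visitor` implementation below.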
struct TokenizerVisitor<M, N, PT, PP, D>(
    PhantomData<M>,
    PhantomData<N>,
    PhantomData<PT>,
    PhantomData<PP>,
    PhantomData<D>,
);

impl<'de, M, N, PT, PP, D> Visitor<'de> for TokenizerVisitor<M, N, PT, PP, D>
where
    M: Deserialize<'de> + Model,
    N: Deserialize<'de> + Normalizer,
    PT: Deserialize<'de> + PreTokenizer,
    PP: Deserialize<'de> + PostProcessor,
    D: Deserialize<'de> + Decoder,
{
    type Value = TokenizerImpl<M, N, PT, PP, D>;

    fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(fmt, "struct Tokenizer")
    }

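    // `visit_map` drives the whole deserialization: each recognized key feeds
    // the builder, `added_tokens` is buffered so the stored ids can be checked
    // once the tokenizer is built, and unrecognized keys are skipped.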
    fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
    where
        V: MapAccess<'de>,
    {
        let mut builder = TokenizerBuilder::new();
        let mut tokens: Vec<AddedTokenWithId> = vec![];
        while let Some(key) = map.next_key::<String>()? {
            match key.as_ref() {
                "version" => {
                    let v: String = map.next_value()?;
                    if v != SERIALIZATION_VERSION {
                        return Err(Error::custom(format!("Unknown tokenizer version '{v}'")));
                    }
                }
                "truncation" => {
                    builder = builder.with_truncation(map.next_value()?);
                }
                "padding" => {
                    builder = builder.with_padding(map.next_value()?);
                }
                "added_tokens" => {
                    tokens = map.next_value()?;
                }
                "normalizer" => {
                    builder = builder.with_normalizer(map.next_value()?);
                }
                "pre_tokenizer" => {
                    builder = builder.with_pre_tokenizer(map.next_value()?);
                }
                "model" => {
                    builder = builder.with_model(map.next_value()?);
                }
                "decoder" => {
                    builder = builder.with_decoder(map.next_value()?);
                }
                "post_processor" => {
                    builder = builder.with_post_processor(map.next_value()?);
                }
                _ => {
                    // Unknown keys must still have their value consumed,
                    // otherwise the `MapAccess` stream gets out of sync.
                    let _: IgnoredAny = map.next_value()?;
                }
            };
        }
        let mut tokenizer = builder
            .build()
            .map_err(|e| V::Error::custom(e.to_string()))?;

        // We deserialize the added tokens as `AddedTokenWithId` (instead of
        // `AddedVocabulary` directly) because it lets us check that the
        // associated ids are still valid, and warn the user otherwise.
        for token in &tokens {
            // Warn the user if the id differs from the expected one
            if let Some(received_id) = tokenizer.token_to_id(&token.token.content) {
                if received_id != token.id {
                    warn!(
                        "Token '{}' was expected to have ID '{}' but was given ID '{}'",
                        token.token.content, token.id, received_id
                    );
                }
            }
        }
        let added_tokens: Vec<_> = tokens.into_iter().map(|token| token.token).collect();
        tokenizer.add_tokens(&added_tokens[..]);

        Ok(tokenizer)
    }
}

#[cfg(test)]
mod tests {
    use crate::tokenizer::Tokenizer;
    use std::str::FromStr;

    #[test]
    fn test_deserialization_serialization_invariant() {
        let tok_json = r#"{
184  "version": "1.0",
185  "truncation": null,
186  "padding": null,
187  "added_tokens": [
188    {
189      "id": 0,
190      "content": "[SPECIAL_0]",
191      "single_word": false,
192      "lstrip": false,
193      "rstrip": false,
194      "normalized": false,
195      "special": true
196    },
197    {
198      "id": 1,
199      "content": "[SPECIAL_1]",
200      "single_word": false,
201      "lstrip": false,
202      "rstrip": false,
203      "normalized": true,
204      "special": false
205    },
206    {
207      "id": 2,
208      "content": "[SPECIAL_2]",
209      "single_word": false,
210      "lstrip": false,
211      "rstrip": false,
212      "normalized": false,
213      "special": true
214    }
215  ],
216  "normalizer": null,
217  "pre_tokenizer": null,
218  "post_processor": null,
219  "decoder": null,
220  "model": {
221    "type": "WordPiece",
222    "unk_token": "[UNK]",
223    "continuing_subword_prefix": "",
224    "max_input_chars_per_word": 100,
225    "vocab": {}
226  }
227}"#;
228        let tokenizer = Tokenizer::from_str(tok_json).unwrap();
229
230        let tok_str = serde_json::to_string_pretty(&tokenizer).unwrap();
231        // It should be exactly the same as above
232        assert_eq!(tok_str, tok_json);
233    }
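
    // A sketch (not in the original file) exercising the version guard in
    // `visit_map` above: any version other than "1.0" must be rejected with
    // an "Unknown tokenizer version" error.
    #[test]
    fn test_unknown_version_is_rejected() {
        let tok_json = r#"{"version": "2.0"}"#;
        match Tokenizer::from_str(tok_json) {
            Ok(_) => panic!("expected an error for an unknown version"),
            Err(e) => assert!(e.to_string().contains("Unknown tokenizer version '2.0'")),
        }
    }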

    #[cfg(feature = "http")]
    #[test]
    fn test_from_pretrained() {
        // `warn!` is the `log` crate macro, imported locally so the import is
        // only compiled in when the `http` feature enables this test.
        use log::warn;

        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::DEBUG)
            .with_target(false)
            .init();
        let _ = Tokenizer::from_pretrained("Qwen/Qwen2-7B-Instruct", None);
        warn!("This should be the first warning");
    }
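
    // A minimal sketch (not in the original file), assuming an empty
    // `WordPiece` vocab so added tokens keep their serialized ids: after
    // deserialization the token should resolve to the id stored in the file,
    // and no id-mismatch warning should be emitted.
    #[test]
    fn test_added_token_keeps_its_id() {
        let tok_json = r#"{
  "version": "1.0",
  "added_tokens": [
    {
      "id": 0,
      "content": "[SPECIAL_0]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "model": {
    "type": "WordPiece",
    "unk_token": "[UNK]",
    "continuing_subword_prefix": "",
    "max_input_chars_per_word": 100,
    "vocab": {}
  }
}"#;
        let tokenizer = Tokenizer::from_str(tok_json).unwrap();
        assert_eq!(tokenizer.token_to_id("[SPECIAL_0]"), Some(0));
    }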
}