// tokenizers/models/bpe/serialization.rs

1use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
2use serde::{
3    de::{Error, MapAccess, Visitor},
4    ser::SerializeStruct,
5    Deserialize, Deserializer, Serialize, Serializer,
6};
7use std::collections::HashMap;
8
9impl Serialize for BPE {
10    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11    where
12        S: Serializer,
13    {
14        let mut model = serializer.serialize_struct("BPE", 8)?;
15
16        // Start by small fields
17        model.serialize_field("type", "BPE")?;
18        model.serialize_field("dropout", &self.dropout)?;
19        model.serialize_field("unk_token", &self.unk_token)?;
20        model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
21        model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
22        model.serialize_field("fuse_unk", &self.fuse_unk)?;
23        model.serialize_field("byte_fallback", &self.byte_fallback)?;
24        model.serialize_field("ignore_merges", &self.ignore_merges)?;
25
26        // Then the large ones
27        let mut merges: Vec<(&Pair, &u32)> = self
28            .merges
29            .iter()
30            .map(|(pair, (rank, _))| (pair, rank))
31            .collect();
32        merges.sort_unstable_by_key(|k| *k.1);
33        let merges = merges
34            .into_iter()
35            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
36            .collect::<Vec<_>>();
37        let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
38
39        model.serialize_field("vocab", &ordered_vocab)?;
40        model.serialize_field("merges", &merges)?;
41
42        model.end()
43    }
44}
45
46impl<'de> Deserialize<'de> for BPE {
47    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
48    where
49        D: Deserializer<'de>,
50    {
51        deserializer.deserialize_struct(
52            "BPE",
53            &[
54                "type",
55                "dropout",
56                "unk_token",
57                "continuing_subword_prefix",
58                "end_of_word_suffix",
59                "fuse_unk",
60                "byte_fallback",
61                "ignore_merges",
62                "vocab",
63                "merges",
64            ],
65            BPEVisitor,
66        )
67    }
68}
69
70struct BPEVisitor;
71impl<'de> Visitor<'de> for BPEVisitor {
72    type Value = BPE;
73
74    fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
75        write!(fmt, "struct BPE")
76    }
77
78    fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
79    where
80        V: MapAccess<'de>,
81    {
82        let mut builder = BpeBuilder::new();
83        let mut vocab: Option<HashMap<String, u32>> = None;
84
85        #[derive(Debug, Deserialize)]
86        #[serde(untagged)]
87        enum MergeType {
88            Tuple(Vec<(String, String)>),
89            Legacy(Vec<String>),
90        }
91        let mut merges: Option<MergeType> = None;
92        while let Some(key) = map.next_key::<String>()? {
93            match key.as_ref() {
94                "dropout" => {
95                    if let Some(dropout) = map.next_value()? {
96                        builder = builder.dropout(dropout);
97                    }
98                }
99                "unk_token" => {
100                    if let Some(unk) = map.next_value()? {
101                        builder = builder.unk_token(unk);
102                    }
103                }
104                "continuing_subword_prefix" => {
105                    if let Some(prefix) = map.next_value()? {
106                        builder = builder.continuing_subword_prefix(prefix);
107                    }
108                }
109                "end_of_word_suffix" => {
110                    if let Some(suffix) = map.next_value()? {
111                        builder = builder.end_of_word_suffix(suffix);
112                    }
113                }
114                "fuse_unk" => {
115                    if let Some(suffix) = map.next_value()? {
116                        builder = builder.fuse_unk(suffix);
117                    }
118                }
119                "byte_fallback" => {
120                    if let Some(suffix) = map.next_value()? {
121                        builder = builder.byte_fallback(suffix);
122                    }
123                }
124                "ignore_merges" => {
125                    if let Some(suffix) = map.next_value()? {
126                        builder = builder.ignore_merges(suffix);
127                    }
128                }
129                "vocab" => vocab = Some(map.next_value()?),
130                "merges" => merges = Some(map.next_value()?),
131                "type" => match map.next_value()? {
132                    "BPE" => {}
133                    u => {
134                        return Err(serde::de::Error::invalid_value(
135                            serde::de::Unexpected::Str(u),
136                            &"BPE",
137                        ))
138                    }
139                },
140                _ => {}
141            }
142        }
143        if let (Some(vocab), Some(merges)) = (vocab, merges) {
144            let merges = match merges {
145                MergeType::Tuple(merges) => merges,
146                MergeType::Legacy(merges) => {
147                    convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
148                }
149            };
150            builder = builder.vocab_and_merges(vocab, merges);
151            Ok(builder.build().map_err(Error::custom)?)
152        } else {
153            Err(Error::custom("Missing vocab/merges"))
154        }
155    }
156}
157
#[cfg(test)]
mod test {
    use super::*;
    use crate::models::bpe::Vocab;

    /// Build a `Vocab` from `(token, id)` pairs.
    fn build_vocab(entries: &[(&str, u32)]) -> Vocab {
        entries
            .iter()
            .map(|(token, id)| (token.to_string(), *id))
            .collect()
    }

    #[test]
    fn test_serialization() {
        let vocab = build_vocab(&[("<unk>", 0), ("a", 1), ("b", 2), ("ab", 3)]);
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        // The legacy format stores each merge as one space-separated string.
        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let legacy = serde_json::from_str(legacy).unwrap();
        assert_eq!(bpe, legacy);

        // Serialization always emits the tuple merge format.
        let data = serde_json::to_string(&bpe).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(bpe, reconstructed);

        // With a space in the token: round-trips only via the tuple format.
        let vocab = build_vocab(&[("<unk>", 0), ("a", 1), ("b c d", 2), ("ab c d", 3)]);
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();
        let data = serde_json::to_string(&bpe).unwrap();
        assert_eq!(
            data,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
        );
        let reconstructed = serde_json::from_str(&data).unwrap();
        assert_eq!(bpe, reconstructed);
    }

    #[test]
    fn test_serialization_ignore_merges() {
        let vocab = build_vocab(&[("<unk>", 0), ("a", 1), ("b", 2)]);
        let mut bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap();

        // Explicit ignore_merges flag deserializes as given.
        let bpe_string = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        assert_eq!(serde_json::from_str::<BPE>(bpe_string).unwrap(), bpe);

        // A missing ignore_merges field falls back to the builder default (false).
        bpe.ignore_merges = false;
        let bpe_string = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        assert_eq!(serde_json::from_str::<BPE>(bpe_string).unwrap(), bpe);
    }
}