1use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
2use serde::{
3 de::{Error, MapAccess, Visitor},
4 ser::SerializeStruct,
5 Deserialize, Deserializer, Serialize, Serializer,
6};
7use std::collections::HashMap;
8
9impl Serialize for BPE {
10 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11 where
12 S: Serializer,
13 {
14 let mut model = serializer.serialize_struct("BPE", 8)?;
15
16 model.serialize_field("type", "BPE")?;
18 model.serialize_field("dropout", &self.dropout)?;
19 model.serialize_field("unk_token", &self.unk_token)?;
20 model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
21 model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
22 model.serialize_field("fuse_unk", &self.fuse_unk)?;
23 model.serialize_field("byte_fallback", &self.byte_fallback)?;
24 model.serialize_field("ignore_merges", &self.ignore_merges)?;
25
26 let mut merges: Vec<(&Pair, &u32)> = self
28 .merges
29 .iter()
30 .map(|(pair, (rank, _))| (pair, rank))
31 .collect();
32 merges.sort_unstable_by_key(|k| *k.1);
33 let merges = merges
34 .into_iter()
35 .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
36 .collect::<Vec<_>>();
37 let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
38
39 model.serialize_field("vocab", &ordered_vocab)?;
40 model.serialize_field("merges", &merges)?;
41
42 model.end()
43 }
44}
45
46impl<'de> Deserialize<'de> for BPE {
47 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
48 where
49 D: Deserializer<'de>,
50 {
51 deserializer.deserialize_struct(
52 "BPE",
53 &[
54 "type",
55 "dropout",
56 "unk_token",
57 "continuing_subword_prefix",
58 "end_of_word_suffix",
59 "fuse_unk",
60 "byte_fallback",
61 "ignore_merges",
62 "vocab",
63 "merges",
64 ],
65 BPEVisitor,
66 )
67 }
68}
69
70struct BPEVisitor;
71impl<'de> Visitor<'de> for BPEVisitor {
72 type Value = BPE;
73
74 fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
75 write!(fmt, "struct BPE")
76 }
77
78 fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
79 where
80 V: MapAccess<'de>,
81 {
82 let mut builder = BpeBuilder::new();
83 let mut vocab: Option<HashMap<String, u32>> = None;
84
85 #[derive(Debug, Deserialize)]
86 #[serde(untagged)]
87 enum MergeType {
88 Tuple(Vec<(String, String)>),
89 Legacy(Vec<String>),
90 }
91 let mut merges: Option<MergeType> = None;
92 while let Some(key) = map.next_key::<String>()? {
93 match key.as_ref() {
94 "dropout" => {
95 if let Some(dropout) = map.next_value()? {
96 builder = builder.dropout(dropout);
97 }
98 }
99 "unk_token" => {
100 if let Some(unk) = map.next_value()? {
101 builder = builder.unk_token(unk);
102 }
103 }
104 "continuing_subword_prefix" => {
105 if let Some(prefix) = map.next_value()? {
106 builder = builder.continuing_subword_prefix(prefix);
107 }
108 }
109 "end_of_word_suffix" => {
110 if let Some(suffix) = map.next_value()? {
111 builder = builder.end_of_word_suffix(suffix);
112 }
113 }
114 "fuse_unk" => {
115 if let Some(suffix) = map.next_value()? {
116 builder = builder.fuse_unk(suffix);
117 }
118 }
119 "byte_fallback" => {
120 if let Some(suffix) = map.next_value()? {
121 builder = builder.byte_fallback(suffix);
122 }
123 }
124 "ignore_merges" => {
125 if let Some(suffix) = map.next_value()? {
126 builder = builder.ignore_merges(suffix);
127 }
128 }
129 "vocab" => vocab = Some(map.next_value()?),
130 "merges" => merges = Some(map.next_value()?),
131 "type" => match map.next_value()? {
132 "BPE" => {}
133 u => {
134 return Err(serde::de::Error::invalid_value(
135 serde::de::Unexpected::Str(u),
136 &"BPE",
137 ))
138 }
139 },
140 _ => {}
141 }
142 }
143 if let (Some(vocab), Some(merges)) = (vocab, merges) {
144 let merges = match merges {
145 MergeType::Tuple(merges) => merges,
146 MergeType::Legacy(merges) => {
147 convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
148 }
149 };
150 builder = builder.vocab_and_merges(vocab, merges);
151 Ok(builder.build().map_err(Error::custom)?)
152 } else {
153 Err(Error::custom("Missing vocab/merges"))
154 }
155 }
156}
157
#[cfg(test)]
mod test {
    use super::*;
    use crate::models::bpe::Vocab;

    /// Builds a [`Vocab`] from `(token, id)` pairs.
    fn vocab_of(entries: &[(&str, u32)]) -> Vocab {
        entries
            .iter()
            .map(|&(token, id)| (token.to_string(), id))
            .collect()
    }

    /// Builds a model with `<unk>` as the unknown token and `ignore_merges` enabled.
    fn build_bpe(vocab: Vocab, merges: Vec<(String, String)>) -> BPE {
        BpeBuilder::default()
            .vocab_and_merges(vocab, merges)
            .unk_token("<unk>".to_string())
            .ignore_merges(true)
            .build()
            .unwrap()
    }

    /// Checks that `bpe` serializes to exactly `expected_json` and that
    /// parsing that output yields an equal model.
    fn assert_round_trip(bpe: &BPE, expected_json: &str) {
        let serialized = serde_json::to_string(bpe).unwrap();
        assert_eq!(serialized, expected_json);
        let reconstructed: BPE = serde_json::from_str(&serialized).unwrap();
        assert_eq!(bpe, &reconstructed);
    }

    #[test]
    fn test_serialization() {
        let bpe = build_bpe(
            vocab_of(&[("<unk>", 0), ("a", 1), ("b", 2), ("ab", 3)]),
            vec![("a".to_string(), "b".to_string())],
        );

        // The legacy format stores each merge as one space-separated string.
        let legacy_json = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
        let from_legacy: BPE = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(bpe, from_legacy);

        // The current format writes each merge as a pair of strings.
        assert_round_trip(
            &bpe,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#,
        );

        // Tokens containing spaces must survive the round trip.
        let spaced = build_bpe(
            vocab_of(&[("<unk>", 0), ("a", 1), ("b c d", 2), ("ab c d", 3)]),
            vec![("a".to_string(), "b c d".to_string())],
        );
        assert_round_trip(
            &spaced,
            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#,
        );
    }

    #[test]
    fn test_serialization_ignore_merges() {
        let mut bpe = build_bpe(vocab_of(&[("<unk>", 0), ("a", 1), ("b", 2)]), vec![]);

        let with_flag = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        assert_eq!(serde_json::from_str::<BPE>(with_flag).unwrap(), bpe);

        // Input without the `ignore_merges` key must equal a model with the flag off.
        bpe.ignore_merges = false;
        let without_flag = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"vocab":{"<unk>":0,"a":1,"b":2},"merges":[]}"#;
        assert_eq!(serde_json::from_str::<BPE>(without_flag).unwrap(), bpe);
    }
}