tokenizers/models/wordlevel/
serialization.rs1use super::{super::OrderedVocabIter, WordLevel, WordLevelBuilder};
2use serde::{
3 de::{MapAccess, Visitor},
4 ser::SerializeStruct,
5 Deserialize, Deserializer, Serialize, Serializer,
6};
7use std::collections::HashSet;
8
9impl Serialize for WordLevel {
10 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11 where
12 S: Serializer,
13 {
14 let mut model = serializer.serialize_struct("WordLevel", 3)?;
15 let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
16 model.serialize_field("type", "WordLevel")?;
17 model.serialize_field("vocab", &ordered_vocab)?;
18 model.serialize_field("unk_token", &self.unk_token)?;
19 model.end()
20 }
21}
22
23impl<'de> Deserialize<'de> for WordLevel {
24 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25 where
26 D: Deserializer<'de>,
27 {
28 deserializer.deserialize_struct(
29 "WordLevel",
30 &["type", "vocab", "unk_token"],
31 WordLevelVisitor,
32 )
33 }
34}
35
36struct WordLevelVisitor;
37impl<'de> Visitor<'de> for WordLevelVisitor {
38 type Value = WordLevel;
39
40 fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
41 write!(fmt, "struct WordLevel")
42 }
43
44 fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
45 where
46 V: MapAccess<'de>,
47 {
48 let mut builder = WordLevelBuilder::new();
49 let mut missing_fields = vec![
50 "unk_token",
52 "vocab",
53 ]
54 .into_iter()
55 .collect::<HashSet<_>>();
56 while let Some(key) = map.next_key::<String>()? {
57 match key.as_ref() {
58 "vocab" => builder = builder.vocab(map.next_value()?),
59 "unk_token" => builder = builder.unk_token(map.next_value()?),
60 "type" => match map.next_value()? {
61 "WordLevel" => {}
62 u => {
63 return Err(serde::de::Error::invalid_value(
64 serde::de::Unexpected::Str(u),
65 &"WordLevel",
66 ))
67 }
68 },
69 _ => {}
70 }
71 missing_fields.remove::<str>(&key);
72 }
73
74 if !missing_fields.is_empty() {
75 Err(serde::de::Error::missing_field(
76 missing_fields.iter().next().unwrap(),
77 ))
78 } else {
79 Ok(builder.build().map_err(serde::de::Error::custom)?)
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};

    /// Deserializes `json` as a `WordLevel`, expects a failure, and returns
    /// the rendered error message.
    fn deserialization_error(json: &str) -> String {
        serde_json::from_str::<WordLevel>(json)
            .unwrap_err()
            .to_string()
    }

    /// The default model round-trips through its JSON form.
    #[test]
    fn serde() {
        let expected_json = r#"{"type":"WordLevel","vocab":{},"unk_token":"<unk>"}"#;
        let model = WordLevel::default();

        let serialized = serde_json::to_string(&model).unwrap();
        assert_eq!(serialized, expected_json);

        let deserialized: WordLevel = serde_json::from_str(expected_json).unwrap();
        assert_eq!(deserialized, model);
    }

    /// A vocabulary whose ids are not contiguous (0 and 2) still round-trips.
    #[test]
    fn incomplete_vocab() {
        let entries = vec![("<unk>".into(), 0), ("b".into(), 2)];
        let vocab: Vocab = entries.into_iter().collect();

        let model = WordLevelBuilder::default()
            .vocab(vocab)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let expected_json = r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#;

        let serialized = serde_json::to_string(&model).unwrap();
        assert_eq!(serialized, expected_json);

        let deserialized: WordLevel = serde_json::from_str(expected_json).unwrap();
        assert_eq!(deserialized, model);
    }

    /// Missing `unk_token` and a wrong `type` tag are both rejected.
    #[test]
    fn deserialization_should_fail() {
        let missing_unk_error = deserialization_error(r#"{"type":"WordLevel","vocab":{}}"#);
        assert!(missing_unk_error.starts_with("missing field `unk_token`"));

        let wrong_type_error = deserialization_error(r#"{"type":"WordPiece","vocab":{}}"#);
        assert!(
            wrong_type_error.starts_with("invalid value: string \"WordPiece\", expected WordLevel")
        );
    }
}