tokenizers/models/wordlevel/
serialization.rs

1use super::{super::OrderedVocabIter, WordLevel, WordLevelBuilder};
2use serde::{
3    de::{MapAccess, Visitor},
4    ser::SerializeStruct,
5    Deserialize, Deserializer, Serialize, Serializer,
6};
7use std::collections::HashSet;
8
9impl Serialize for WordLevel {
10    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11    where
12        S: Serializer,
13    {
14        let mut model = serializer.serialize_struct("WordLevel", 3)?;
15        let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
16        model.serialize_field("type", "WordLevel")?;
17        model.serialize_field("vocab", &ordered_vocab)?;
18        model.serialize_field("unk_token", &self.unk_token)?;
19        model.end()
20    }
21}
22
23impl<'de> Deserialize<'de> for WordLevel {
24    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25    where
26        D: Deserializer<'de>,
27    {
28        deserializer.deserialize_struct(
29            "WordLevel",
30            &["type", "vocab", "unk_token"],
31            WordLevelVisitor,
32        )
33    }
34}
35
36struct WordLevelVisitor;
37impl<'de> Visitor<'de> for WordLevelVisitor {
38    type Value = WordLevel;
39
40    fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
41        write!(fmt, "struct WordLevel")
42    }
43
44    fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
45    where
46        V: MapAccess<'de>,
47    {
48        let mut builder = WordLevelBuilder::new();
49        let mut missing_fields = vec![
50            // for retrocompatibility the "type" field is not mandatory
51            "unk_token",
52            "vocab",
53        ]
54        .into_iter()
55        .collect::<HashSet<_>>();
56        while let Some(key) = map.next_key::<String>()? {
57            match key.as_ref() {
58                "vocab" => builder = builder.vocab(map.next_value()?),
59                "unk_token" => builder = builder.unk_token(map.next_value()?),
60                "type" => match map.next_value()? {
61                    "WordLevel" => {}
62                    u => {
63                        return Err(serde::de::Error::invalid_value(
64                            serde::de::Unexpected::Str(u),
65                            &"WordLevel",
66                        ))
67                    }
68                },
69                _ => {}
70            }
71            missing_fields.remove::<str>(&key);
72        }
73
74        if !missing_fields.is_empty() {
75            Err(serde::de::Error::missing_field(
76                missing_fields.iter().next().unwrap(),
77            ))
78        } else {
79            Ok(builder.build().map_err(serde::de::Error::custom)?)
80        }
81    }
82}
83
#[cfg(test)]
mod tests {
    use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};

    /// The default model must round-trip through JSON in both directions.
    #[test]
    fn serde() {
        let expected_json = r#"{"type":"WordLevel","vocab":{},"unk_token":"<unk>"}"#;
        let model = WordLevel::default();

        let encoded = serde_json::to_string(&model).unwrap();
        assert_eq!(encoded, expected_json);

        let decoded = serde_json::from_str::<WordLevel>(expected_json).unwrap();
        assert_eq!(decoded, model);
    }

    /// A vocabulary whose ids have a gap (0 and 2, no 1) must still round-trip.
    #[test]
    fn incomplete_vocab() {
        let entries = vec![("<unk>".into(), 0), ("b".into(), 2)];
        let vocab: Vocab = entries.into_iter().collect();

        let model = WordLevelBuilder::default()
            .vocab(vocab)
            .unk_token("<unk>".to_string())
            .build()
            .unwrap();

        let expected_json = r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#;

        let encoded = serde_json::to_string(&model).unwrap();
        assert_eq!(encoded, expected_json);

        let decoded = serde_json::from_str::<WordLevel>(expected_json).unwrap();
        assert_eq!(decoded, model);
    }

    /// Malformed inputs must be rejected with the expected error messages.
    #[test]
    fn deserialization_should_fail() {
        // Deserializes `json`, expecting a failure, and returns the error text.
        let error_text = |json: &str| {
            serde_json::from_str::<WordLevel>(json)
                .unwrap_err()
                .to_string()
        };

        let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#;
        assert!(error_text(missing_unk).starts_with("missing field `unk_token`"));

        let wrong_type = r#"{"type":"WordPiece","vocab":{}}"#;
        assert!(error_text(wrong_type)
            .starts_with("invalid value: string \"WordPiece\", expected WordLevel"));
    }
}