tf_idf_vectorizer/utils/datastruct/map/serde.rs

1use std::{
2    fmt,
3    hash::Hash,
4    marker::PhantomData,
5};
6
7use serde::{Deserialize, Deserializer, Serialize, ser::SerializeStruct};
8
9use crate::utils::datastruct::map::IndexMap;
10
11impl<K, V> Serialize for IndexMap<K, V>
12where
13    K: Serialize,
14    V: Serialize,
15{
16    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
17    where
18        S: serde::Serializer,
19    {
20        let mut state = serializer.serialize_struct("IndexMap", 2)?;
21        state.serialize_field("values", &self.values)?;
22        state.serialize_field("keys", &self.index_set.keys)?;
23        state.end()
24    }
25}
26
27impl<'de, K, V> Deserialize<'de> for IndexMap<K, V>
28where
29    K: Deserialize<'de> + Hash + Eq + Clone,
30    V: Deserialize<'de>,
31{
32    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33    where
34        D: Deserializer<'de>,
35    {
36        #[derive(Debug)]
37        enum Field {
38            Values,
39            Keys,
40        }
41
42        impl<'de> Deserialize<'de> for Field {
43            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44            where
45                D: Deserializer<'de>,
46            {
47                struct FieldVisitor;
48
49                impl<'de> serde::de::Visitor<'de> for FieldVisitor {
50                    type Value = Field;
51
52                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53                        formatter.write_str("`values` or `keys`")
54                    }
55
56                    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
57                    where
58                        E: serde::de::Error,
59                    {
60                        match v {
61                            "values" => Ok(Field::Values),
62                            "keys" => Ok(Field::Keys),
63                            _ => Err(E::unknown_field(v, FIELDS)),
64                        }
65                    }
66
67                    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
68                    where
69                        E: serde::de::Error,
70                    {
71                        match v {
72                            b"values" => Ok(Field::Values),
73                            b"keys" => Ok(Field::Keys),
74                            _ => {
75                                let s = std::str::from_utf8(v).unwrap_or("");
76                                Err(E::unknown_field(s, FIELDS))
77                            }
78                        }
79                    }
80                }
81
82                deserializer.deserialize_identifier(FieldVisitor)
83            }
84        }
85
86        struct IndexMapVisitor<K, V>(PhantomData<(K, V)>);
87
88        const FIELDS: &[&str] = &["values", "keys"];
89
90        impl<'de, K, V> serde::de::Visitor<'de> for IndexMapVisitor<K, V>
91        where
92            K: Deserialize<'de> + Hash + Eq + Clone,
93            V: Deserialize<'de>,
94        {
95            type Value = IndexMap<K, V>;
96
97            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
98                formatter.write_str("an IndexMap serialized as { values: Vec<V>, keys: Vec<K> }")
99            }
100
101            fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
102            where
103                M: serde::de::MapAccess<'de>,
104            {
105                use serde::de::Error as DeError;
106
107                let mut values: Option<Vec<V>> = None;
108                let mut keys: Option<Vec<K>> = None;
109
110                while let Some(field) = access.next_key::<Field>()? {
111                    match field {
112                        Field::Values => {
113                            if values.is_some() {
114                                return Err(DeError::duplicate_field("values"));
115                            }
116                            values = Some(access.next_value()?);
117                        }
118                        Field::Keys => {
119                            if keys.is_some() {
120                                return Err(DeError::duplicate_field("keys"));
121                            }
122                            keys = Some(access.next_value()?);
123                        }
124                    }
125                }
126
127                let values = values.ok_or_else(|| DeError::missing_field("values"))?;
128                let keys = keys.ok_or_else(|| DeError::missing_field("keys"))?;
129
130                if keys.len() != values.len() {
131                    return Err(DeError::custom("IndexMap deserialize error: keys and values length mismatch"));
132                }
133
134                Ok(IndexMap::from_kv_vec(keys, values))
135            }
136
137            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
138            where
139                A: serde::de::SeqAccess<'de>,
140            {
141                use serde::de::Error as DeError;
142
143                let values: Vec<V> = seq
144                    .next_element()?
145                    .ok_or_else(|| DeError::invalid_length(0, &self))?;
146                let keys: Vec<K> = seq
147                    .next_element()?
148                    .ok_or_else(|| DeError::invalid_length(1, &self))?;
149
150                if keys.len() != values.len() {
151                    return Err(DeError::custom("IndexMap deserialize error: keys and values length mismatch"));
152                }
153
154                Ok(IndexMap::from_kv_vec(keys, values))
155            }
156        }
157
158        deserializer.deserialize_struct("IndexMap", FIELDS, IndexMapVisitor(PhantomData))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips through JSON (map-shaped input, exercising `visit_map`) and checks
    // that insertion order, values, the hash table, and key lookup all survive.
    #[test]
    fn serde_roundtrip_json_map_format_preserves_order_and_lookup() {
        let mut m = IndexMap::<String, i64>::new();
        m.insert("a".to_string(), 10);
        m.insert("b".to_string(), 20);
        m.insert("c".to_string(), 30);

        let s = serde_json::to_string(&m).unwrap();
        let de: IndexMap<String, i64> = serde_json::from_str(&s).unwrap();

        assert_eq!(de.index_set.keys, m.index_set.keys);
        assert_eq!(de.values, m.values);
        assert_eq!(de.len(), 3);
        assert_eq!(de.index_set.hashes.len(), de.len());

        for (k, v) in m.iter() {
            assert_eq!(de.get(k).copied(), Some(*v));
        }
    }

    // Round-trips through bincode (sequence-shaped input, exercising `visit_seq`).
    #[test]
    fn serde_roundtrip_bincode_seq_format_works() {
        let mut m = IndexMap::<u64, i64>::new();
        for i in 0..100u64 {
            m.insert(i, (i as i64) * -7);
        }

        let bytes = bincode::serialize(&m).unwrap();
        let de: IndexMap<u64, i64> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(de.index_set.keys, m.index_set.keys);
        assert_eq!(de.values, m.values);
        assert_eq!(de.index_set.hashes.len(), de.len());

        for (k, v) in m.iter() {
            assert_eq!(de.get(k).copied(), Some(*v));
        }
    }

    #[test]
    fn serde_rejects_len_mismatch_json() {
        // values.len() != keys.len().
        // Quotes inside the raw string must be plain `"`: writing `\"` inside
        // `r#"..."#` keeps the backslash, which makes the input invalid JSON and lets
        // the test pass on a syntax error instead of on the length check.
        let bad = r#"{"values":[1,2,3],"keys":["a","b"]}"#;
        let err = serde_json::from_str::<IndexMap<String, i32>>(bad).unwrap_err();
        assert!(
            err.to_string().contains("length mismatch"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn serde_rejects_duplicate_fields_json() {
        // Duplicate object keys are discouraged in JSON, but some parsers accept them,
        // so the deserializer itself should reject a repeated field.
        let dup = r#"{"values":[1],"values":[2],"keys":["a"]}"#;
        let err = serde_json::from_str::<IndexMap<String, i32>>(dup).unwrap_err();
        assert!(
            err.to_string().contains("duplicate field `values`"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn serde_rejects_unknown_field_json() {
        let extra = r#"{"values":[],"keys":[],"extra":1}"#;
        let err = serde_json::from_str::<IndexMap<String, i32>>(extra).unwrap_err();
        assert!(
            err.to_string().contains("unknown field `extra`"),
            "unexpected error: {}",
            err
        );
    }
}
220}