tf_idf_vectorizer/utils/datastruct/map/
serde.rs1use std::{
2 fmt,
3 hash::Hash,
4 marker::PhantomData,
5};
6
7use serde::{Deserialize, Deserializer, Serialize, ser::SerializeStruct};
8
9use crate::utils::datastruct::map::IndexMap;
10
11impl<K, V> Serialize for IndexMap<K, V>
12where
13 K: Serialize,
14 V: Serialize,
15{
16 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
17 where
18 S: serde::Serializer,
19 {
20 let mut state = serializer.serialize_struct("IndexMap", 2)?;
21 state.serialize_field("values", &self.values)?;
22 state.serialize_field("keys", &self.index_set.keys)?;
23 state.end()
24 }
25}
26
27impl<'de, K, V> Deserialize<'de> for IndexMap<K, V>
28where
29 K: Deserialize<'de> + Hash + Eq + Clone,
30 V: Deserialize<'de>,
31{
32 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 #[derive(Debug)]
37 enum Field {
38 Values,
39 Keys,
40 }
41
42 impl<'de> Deserialize<'de> for Field {
43 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct FieldVisitor;
48
49 impl<'de> serde::de::Visitor<'de> for FieldVisitor {
50 type Value = Field;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("`values` or `keys`")
54 }
55
56 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
57 where
58 E: serde::de::Error,
59 {
60 match v {
61 "values" => Ok(Field::Values),
62 "keys" => Ok(Field::Keys),
63 _ => Err(E::unknown_field(v, FIELDS)),
64 }
65 }
66
67 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
68 where
69 E: serde::de::Error,
70 {
71 match v {
72 b"values" => Ok(Field::Values),
73 b"keys" => Ok(Field::Keys),
74 _ => {
75 let s = std::str::from_utf8(v).unwrap_or("");
76 Err(E::unknown_field(s, FIELDS))
77 }
78 }
79 }
80 }
81
82 deserializer.deserialize_identifier(FieldVisitor)
83 }
84 }
85
86 struct IndexMapVisitor<K, V>(PhantomData<(K, V)>);
87
88 const FIELDS: &[&str] = &["values", "keys"];
89
90 impl<'de, K, V> serde::de::Visitor<'de> for IndexMapVisitor<K, V>
91 where
92 K: Deserialize<'de> + Hash + Eq + Clone,
93 V: Deserialize<'de>,
94 {
95 type Value = IndexMap<K, V>;
96
97 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
98 formatter.write_str("an IndexMap serialized as { values: Vec<V>, keys: Vec<K> }")
99 }
100
101 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
102 where
103 M: serde::de::MapAccess<'de>,
104 {
105 use serde::de::Error as DeError;
106
107 let mut values: Option<Vec<V>> = None;
108 let mut keys: Option<Vec<K>> = None;
109
110 while let Some(field) = access.next_key::<Field>()? {
111 match field {
112 Field::Values => {
113 if values.is_some() {
114 return Err(DeError::duplicate_field("values"));
115 }
116 values = Some(access.next_value()?);
117 }
118 Field::Keys => {
119 if keys.is_some() {
120 return Err(DeError::duplicate_field("keys"));
121 }
122 keys = Some(access.next_value()?);
123 }
124 }
125 }
126
127 let values = values.ok_or_else(|| DeError::missing_field("values"))?;
128 let keys = keys.ok_or_else(|| DeError::missing_field("keys"))?;
129
130 if keys.len() != values.len() {
131 return Err(DeError::custom("IndexMap deserialize error: keys and values length mismatch"));
132 }
133
134 Ok(IndexMap::from_kv_vec(keys, values))
135 }
136
137 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
138 where
139 A: serde::de::SeqAccess<'de>,
140 {
141 use serde::de::Error as DeError;
142
143 let values: Vec<V> = seq
144 .next_element()?
145 .ok_or_else(|| DeError::invalid_length(0, &self))?;
146 let keys: Vec<K> = seq
147 .next_element()?
148 .ok_or_else(|| DeError::invalid_length(1, &self))?;
149
150 if keys.len() != values.len() {
151 return Err(DeError::custom("IndexMap deserialize error: keys and values length mismatch"));
152 }
153
154 Ok(IndexMap::from_kv_vec(keys, values))
155 }
156 }
157
158 deserializer.deserialize_struct("IndexMap", FIELDS, IndexMapVisitor(PhantomData))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON presents the struct as a map of named fields; the round trip must
    /// keep insertion order and leave key lookup working.
    #[test]
    fn serde_roundtrip_json_map_format_preserves_order_and_lookup() {
        let mut m = IndexMap::<String, i64>::new();
        m.insert("a".to_string(), 10);
        m.insert("b".to_string(), 20);
        m.insert("c".to_string(), 30);

        let s = serde_json::to_string(&m).unwrap();
        let de: IndexMap<String, i64> = serde_json::from_str(&s).unwrap();

        assert_eq!(de.index_set.keys, m.index_set.keys);
        assert_eq!(de.values, m.values);
        assert_eq!(de.len(), 3);
        assert_eq!(de.index_set.hashes.len(), de.len());

        for (k, v) in m.iter() {
            assert_eq!(de.get(k).copied(), Some(*v));
        }
    }

    /// bincode presents the struct as a sequence, exercising the positional path.
    #[test]
    fn serde_roundtrip_bincode_seq_format_works() {
        let mut m = IndexMap::<u64, i64>::new();
        for i in 0..100u64 {
            m.insert(i, (i as i64) * -7);
        }

        let bytes = bincode::serialize(&m).unwrap();
        let de: IndexMap<u64, i64> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(de.index_set.keys, m.index_set.keys);
        assert_eq!(de.values, m.values);
        assert_eq!(de.index_set.hashes.len(), de.len());

        for (k, v) in m.iter() {
            assert_eq!(de.get(k).copied(), Some(*v));
        }
    }

    #[test]
    fn serde_rejects_len_mismatch_json() {
        // Raw strings take backslashes literally, so the quotes must NOT be
        // escaped: the JSON has to be well-formed for the failure to come from
        // the length check rather than from the JSON parser.
        let bad = r#"{"values":[1,2,3],"keys":["a","b"]}"#;

        let err = serde_json::from_str::<IndexMap<String, i32>>(bad)
            .err()
            .expect("mismatched keys/values lengths must be rejected");
        assert!(
            err.to_string().contains("length mismatch"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn serde_rejects_duplicate_fields_json() {
        // Well-formed JSON with `values` repeated, so the duplicate-field
        // check is what rejects it.
        let dup = r#"{"values":[1],"values":[2],"keys":["a"]}"#;

        let err = serde_json::from_str::<IndexMap<String, i32>>(dup)
            .err()
            .expect("a repeated `values` field must be rejected");
        assert!(
            err.to_string().contains("duplicate field"),
            "unexpected error: {}",
            err
        );
    }
}