trs_data_value/
deserializer.rs

1use std::{collections::HashMap, fmt};
2
3use serde::{
4    de::{Deserialize, Deserializer, MapAccess, Visitor},
5    ser::{Serialize, SerializeMap},
6};
7
8use super::DataValue;
9
10// A Visitor is a type that holds methods that a Deserializer can drive
11// depending on what is contained in the input data.
12//
13struct DataValueVisitor;
14
15// This is the trait that Deserializers are going to be driving. There
16// is one method for each type of data that our type knows how to
17// deserialize from.
18impl<'de> Visitor<'de> for DataValueVisitor {
19    // The type that our Visitor is going to produce.
20    type Value = DataValue;
21
22    // Format a message stating what data this Visitor expects to receive.
23    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
24        formatter.write_str("a very special map")
25    }
26
27    fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
28    where
29        E: serde::de::Error,
30    {
31        Ok(DataValue::Bool(v))
32    }
33
34    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
35    where
36        E: serde::de::Error,
37    {
38        Ok(DataValue::I64(v))
39    }
40
41    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
42    where
43        E: serde::de::Error,
44    {
45        Ok(DataValue::U64(v))
46    }
47
48    fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
49    where
50        E: serde::de::Error,
51    {
52        Ok(DataValue::F64(v))
53    }
54
55    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
56    where
57        E: serde::de::Error,
58    {
59        Ok(DataValue::String(v.into()))
60    }
61
62    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
63    where
64        E: serde::de::Error,
65    {
66        Ok(DataValue::String(v.into()))
67    }
68
69    #[inline]
70    fn visit_unit<E>(self) -> Result<Self::Value, E> {
71        Ok(DataValue::Null)
72    }
73
74    fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
75    where
76        E: serde::de::Error,
77    {
78        Ok(DataValue::F32(v))
79    }
80
81    fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
82    where
83        E: serde::de::Error,
84    {
85        Ok(DataValue::U8(v))
86    }
87
88    fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
89    where
90        E: serde::de::Error,
91    {
92        Ok(DataValue::I128(v))
93    }
94
95    fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
96    where
97        E: serde::de::Error,
98    {
99        Ok(DataValue::U128(v))
100    }
101
102    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
103    where
104        E: serde::de::Error,
105    {
106        Ok(DataValue::Bytes(v.into()))
107    }
108
109    fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
110    where
111        E: serde::de::Error,
112    {
113        Ok(DataValue::Bytes(v))
114    }
115
116    fn visit_none<E>(self) -> Result<Self::Value, E>
117    where
118        E: serde::de::Error,
119    {
120        Ok(DataValue::Null)
121    }
122
123    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
124    where
125        D: Deserializer<'de>,
126    {
127        Deserialize::deserialize(deserializer)
128    }
129
130    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
131    where
132        A: serde::de::SeqAccess<'de>,
133    {
134        let mut vec = Vec::new();
135        while let Ok(Some(value)) = seq.next_element() {
136            vec.push(value);
137        }
138        Ok(DataValue::Vec(vec))
139    }
140
141    // Deserialize MyMap from an abstract "map" provided by the
142    // Deserializer. The MapAccess input is a callback provided by
143    // the Deserializer to let us see each entry in the map.
144    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
145    where
146        M: MapAccess<'de>,
147    {
148        let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
149
150        // While there are entries remaining in the input, add them
151        // into our map.
152        while let Some((key, value)) = access.next_entry()? {
153            map.insert(key, value);
154        }
155
156        Ok(DataValue::Map(map))
157    }
158}
159
160// This is the trait that informs Serde how to deserialize MyMap.
161impl<'de> Deserialize<'de> for DataValue {
162    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
163    where
164        D: Deserializer<'de>,
165    {
166        deserializer.deserialize_any(DataValueVisitor)
167    }
168}
169
170impl Serialize for DataValue {
171    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
172    where
173        S: serde::Serializer,
174    {
175        match self {
176            DataValue::Bool(v) => serializer.serialize_bool(*v),
177            DataValue::I64(v) => serializer.serialize_i64(*v),
178            DataValue::U64(v) => serializer.serialize_u64(*v),
179            DataValue::F64(v) => serializer.serialize_f64(*v),
180            DataValue::String(v) => serializer.serialize_str(v),
181            DataValue::F32(v) => serializer.serialize_f32(*v),
182            DataValue::U8(v) => serializer.serialize_u8(*v),
183            DataValue::I128(v) => serializer.serialize_i128(*v),
184            DataValue::U128(v) => serializer.serialize_u128(*v),
185            DataValue::Bytes(v) => serializer.serialize_bytes(v),
186            DataValue::Vec(v) => v.serialize(serializer),
187            DataValue::Map(v) => {
188                // v.serialize(serializer)
189                let mut map = serializer.serialize_map(Some(v.len()))?;
190                for (k, v) in v {
191                    map.serialize_entry(k, v)?;
192                }
193                map.end()
194            }
195            DataValue::Null => serializer.serialize_none(),
196            DataValue::I32(v) => serializer.serialize_i32(*v),
197            DataValue::U32(v) => serializer.serialize_u32(*v),
198            DataValue::EnumNumber(v) => serializer.serialize_i32(*v),
199        }
200    }
201}
202
203#[cfg(test)]
204mod test {
205
206    use super::*;
207    use rstest::*;
208
209    #[rstest]
210    #[case::bool(true, DataValue::Bool(true))]
211    #[case::i64(-42i64, DataValue::I64(-42))]
212    #[case::f64(42.0f64, DataValue::F64(42.0))]
213    #[case::str("test", DataValue::String("test".into()))]
214    #[case::string("test".to_string(), DataValue::String("test".into()))]
215    #[case::seq(vec![DataValue::I64(-1), DataValue::I64(-2)], DataValue::Vec(vec![DataValue::I64(-1), DataValue::I64(-2)]))]
216    #[case::map({
217        let mut map = HashMap::new();
218        map.insert("key".to_string(), DataValue::U64(42));
219        map
220    }, DataValue::Map(crate::stdhashmap!("key" => DataValue::U64(42))))]
221    fn test_deserialize(#[case] input: impl Into<DataValue>, #[case] expected: DataValue) {
222        let input = input.into();
223        let serialized = serde_json::to_value(&input);
224        assert!(serialized.is_ok(), "{:?}", serialized);
225        println!("{:?}", serialized);
226        let deserialized: Result<DataValue, _> = serde_json::from_value(serialized.unwrap());
227        assert!(deserialized.is_ok(), "{:?}", deserialized);
228        assert_eq!(deserialized.unwrap(), expected);
229    }
230
231    #[derive(Debug, thiserror::Error)]
232    enum DummyError {
233        #[error("Custom error: {0}")]
234        Custom(String),
235    }
236    impl serde::de::Error for DummyError {
237        #[cold]
238        fn custom<T: fmt::Display>(msg: T) -> Self {
239            Self::Custom(msg.to_string())
240        }
241    }
242
243    #[rstest]
244    fn test_visitor() {
245        let v = DataValueVisitor.visit_bool::<DummyError>(true);
246        assert!(v.is_ok());
247        assert_eq!(v.unwrap(), DataValue::Bool(true));
248        let v = DataValueVisitor.visit_i64::<DummyError>(-42);
249        assert!(v.is_ok());
250        assert_eq!(v.unwrap(), DataValue::I64(-42));
251        let v = DataValueVisitor.visit_u64::<DummyError>(42);
252        assert!(v.is_ok());
253        assert_eq!(v.unwrap(), DataValue::U64(42));
254        let v = DataValueVisitor.visit_f64::<DummyError>(42.0);
255        assert!(v.is_ok());
256        assert_eq!(v.unwrap(), DataValue::F64(42.0));
257        let v = DataValueVisitor.visit_str::<DummyError>("test");
258        assert!(v.is_ok());
259        assert_eq!(v.unwrap(), DataValue::String("test".into()));
260        let v = DataValueVisitor.visit_string::<DummyError>("test".to_string());
261        assert!(v.is_ok());
262        assert_eq!(v.unwrap(), DataValue::String("test".into()));
263        let v = DataValueVisitor.visit_f32::<DummyError>(42.0);
264        assert!(v.is_ok());
265        assert_eq!(v.unwrap(), DataValue::F32(42.0));
266        let v = DataValueVisitor.visit_u8::<DummyError>(42);
267        assert!(v.is_ok());
268        assert_eq!(v.unwrap(), DataValue::U8(42));
269        let v = DataValueVisitor.visit_i128::<DummyError>(i128::MAX);
270        assert!(v.is_ok());
271        assert_eq!(v.unwrap(), DataValue::I128(i128::MAX));
272        let v = DataValueVisitor.visit_u128::<DummyError>(u128::MAX);
273        assert!(v.is_ok());
274        assert_eq!(v.unwrap(), DataValue::U128(u128::MAX));
275        let v = DataValueVisitor.visit_bytes::<DummyError>(b"test");
276        assert!(v.is_ok());
277        assert_eq!(v.unwrap(), DataValue::Bytes(b"test".to_vec()));
278        let v = DataValueVisitor.visit_byte_buf::<DummyError>(b"test".to_vec());
279        assert!(v.is_ok());
280        assert_eq!(v.unwrap(), DataValue::Bytes(b"test".to_vec()));
281        let v = DataValueVisitor.visit_none::<DummyError>();
282        assert!(v.is_ok());
283        assert_eq!(v.unwrap(), DataValue::Null);
284    }
285
286    #[rstest]
287    fn serde_simple() {
288        let v: Result<Vec<DataValue>, _> = serde_json::from_str(
289            r#"[
290            253780,
291            0.009369421750307085,
292            1633222860381359,
293            8,
294            5,
295            true,
296            0.16074353018902807,
297            0.4461714007722576,
298            null,
299            0.3,
300            0.3,
301            0.3,
302            -4.660890306625259,
303            null,
304            0
305        ]"#,
306        );
307        assert!(v.is_ok(), "{v:?}");
308    }
309}