// unc_primitives_core/serialize.rs

1use base64::display::Base64Display;
2use base64::engine::general_purpose::GeneralPurpose;
3use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
4use base64::Engine;
5
6pub fn to_base64(input: &[u8]) -> String {
7    BASE64_STANDARD.encode(input)
8}
9
10pub fn base64_display(input: &[u8]) -> Base64Display<'_, 'static, GeneralPurpose> {
11    Base64Display::new(input, &BASE64_STANDARD)
12}
13
14pub fn from_base64(encoded: &str) -> Result<Vec<u8>, base64::DecodeError> {
15    BASE64_STANDARD.decode(encoded)
16}
17
18/// Serialises number as a string; deserialises either as a string or number.
19///
20/// This format works for `u64`, `u128`, `Option<u64>` and `Option<u128>` types.
21/// When serialising, numbers are serialised as decimal strings.  When
22/// deserialising, strings are parsed as decimal numbers while numbers are
23/// interpreted as is.
24pub mod dec_format {
25    use serde::de;
26    use serde::{Deserializer, Serializer};
27
28    #[derive(thiserror::Error, Debug)]
29    #[error("cannot parse from unit")]
30    pub struct ParseUnitError;
31
32    /// Abstraction between integers that we serialise.
33    pub trait DecType: Sized {
34        /// Formats number as a decimal string; passes `None` as is.
35        fn serialize(&self) -> Option<String>;
36
37        /// Constructs Self from a `null` value.  Returns error if this type
38        /// does not accept `null` values.
39        fn try_from_unit() -> Result<Self, ParseUnitError> {
40            Err(ParseUnitError)
41        }
42
43        /// Tries to parse decimal string as an integer.
44        fn try_from_str(value: &str) -> Result<Self, std::num::ParseIntError>;
45
46        /// Constructs Self from a 64-bit unsigned integer.
47        fn from_u64(value: u64) -> Self;
48    }
49
50    impl DecType for u64 {
51        fn serialize(&self) -> Option<String> {
52            Some(self.to_string())
53        }
54        fn try_from_str(value: &str) -> Result<Self, std::num::ParseIntError> {
55            Self::from_str_radix(value, 10)
56        }
57        fn from_u64(value: u64) -> Self {
58            value
59        }
60    }
61
62    impl DecType for u128 {
63        fn serialize(&self) -> Option<String> {
64            Some(self.to_string())
65        }
66        fn try_from_str(value: &str) -> Result<Self, std::num::ParseIntError> {
67            Self::from_str_radix(value, 10)
68        }
69        fn from_u64(value: u64) -> Self {
70            value.into()
71        }
72    }
73
74    impl<T: DecType> DecType for Option<T> {
75        fn serialize(&self) -> Option<String> {
76            self.as_ref().and_then(DecType::serialize)
77        }
78        fn try_from_unit() -> Result<Self, ParseUnitError> {
79            Ok(None)
80        }
81        fn try_from_str(value: &str) -> Result<Self, std::num::ParseIntError> {
82            Some(T::try_from_str(value)).transpose()
83        }
84        fn from_u64(value: u64) -> Self {
85            Some(T::from_u64(value))
86        }
87    }
88
89    struct Visitor<T>(core::marker::PhantomData<T>);
90
91    impl<'de, T: DecType> de::Visitor<'de> for Visitor<T> {
92        type Value = T;
93
94        fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
95            fmt.write_str("a non-negative integer as a string")
96        }
97
98        fn visit_unit<E: de::Error>(self) -> Result<T, E> {
99            T::try_from_unit().map_err(|_| de::Error::invalid_type(de::Unexpected::Option, &self))
100        }
101
102        fn visit_u64<E: de::Error>(self, value: u64) -> Result<T, E> {
103            Ok(T::from_u64(value))
104        }
105
106        fn visit_str<E: de::Error>(self, value: &str) -> Result<T, E> {
107            T::try_from_str(value).map_err(de::Error::custom)
108        }
109    }
110
111    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
112    where
113        D: Deserializer<'de>,
114        T: DecType,
115    {
116        deserializer.deserialize_any(Visitor(Default::default()))
117    }
118
119    pub fn serialize<S, T>(num: &T, serializer: S) -> Result<S::Ok, S::Error>
120    where
121        S: Serializer,
122        T: DecType,
123    {
124        match num.serialize() {
125            Some(value) => serializer.serialize_str(&value),
126            None => serializer.serialize_none(),
127        }
128    }
129}
130
#[test]
fn test_u64_dec_format() {
    #[derive(PartialEq, Debug, serde::Deserialize, serde::Serialize)]
    struct Test {
        #[serde(with = "dec_format")]
        field: u64,
    }

    assert_round_trip("{\"field\":\"0\"}", Test { field: 0 });
    assert_round_trip("{\"field\":\"42\"}", Test { field: 42 });
    assert_round_trip("{\"field\":\"18446744073709551615\"}", Test { field: u64::MAX });
    assert_deserialise("{\"field\":42}", Test { field: 42 });
    // One past `u64::MAX` is rejected whether given as a number or a string.
    assert_de_error::<Test>("{\"field\":18446744073709551616}");
    assert_de_error::<Test>("{\"field\":\"18446744073709551616\"}");
    assert_de_error::<Test>("{\"field\":42.0}");
    // A plain (non-`Option`) integer does not accept `null`.
    assert_de_error::<Test>("{\"field\":null}");
    // Negative values and non-decimal strings are not valid `u64`s.
    assert_de_error::<Test>("{\"field\":-1}");
    assert_de_error::<Test>("{\"field\":\"-1\"}");
    assert_de_error::<Test>("{\"field\":\"abc\"}");
    assert_de_error::<Test>("{\"field\":\"\"}");
}
146
#[test]
fn test_u128_dec_format() {
    #[derive(PartialEq, Debug, serde::Deserialize, serde::Serialize)]
    struct Test {
        #[serde(with = "dec_format")]
        field: u128,
    }

    assert_round_trip("{\"field\":\"42\"}", Test { field: 42 });
    assert_round_trip("{\"field\":\"18446744073709551615\"}", Test { field: u64::MAX as u128 });
    assert_round_trip("{\"field\":\"18446744073709551616\"}", Test { field: 18446744073709551616 });
    // The full `u128` range is representable as a decimal string.
    assert_round_trip(
        "{\"field\":\"340282366920938463463374607431768211455\"}",
        Test { field: u128::MAX },
    );
    assert_deserialise("{\"field\":42}", Test { field: 42 });
    assert_de_error::<Test>("{\"field\":null}");
    assert_de_error::<Test>("{\"field\":42.0}");
    // One past `u128::MAX` overflows and must be rejected.
    assert_de_error::<Test>("{\"field\":\"340282366920938463463374607431768211456\"}");
    // Negative values are not valid unsigned integers.
    assert_de_error::<Test>("{\"field\":-1}");
}
162
#[test]
fn test_option_u128_dec_format() {
    #[derive(PartialEq, Debug, serde::Deserialize, serde::Serialize)]
    struct Test {
        #[serde(with = "dec_format")]
        field: Option<u128>,
    }

    // `None` maps to JSON `null` in both directions.
    assert_round_trip("{\"field\":null}", Test { field: None });

    // `Some` values are written as decimal strings, including past `u64::MAX`.
    let round_trips: [(&str, Option<u128>); 3] = [
        ("{\"field\":\"42\"}", Some(42)),
        ("{\"field\":\"18446744073709551615\"}", Some(u64::MAX as u128)),
        ("{\"field\":\"18446744073709551616\"}", Some(18446744073709551616)),
    ];
    for (json, field) in round_trips {
        assert_round_trip(json, Test { field });
    }

    // Plain JSON integers are accepted on input; floats are not.
    assert_deserialise("{\"field\":42}", Test { field: Some(42) });
    assert_de_error::<Test>("{\"field\":42.0}");
}
184
/// Asserts that `obj` serialises to exactly `serialised` and that
/// `serialised` deserialises back to a value equal to `obj`.
#[cfg(test)]
#[track_caller]
fn assert_round_trip<'a, T>(serialised: &'a str, obj: T)
where
    T: serde::Deserialize<'a> + serde::Serialize + std::fmt::Debug + std::cmp::PartialEq,
{
    let encoded = serde_json::to_string(&obj).unwrap();
    assert_eq!(serialised, encoded);
    let decoded: T = serde_json::from_str(serialised).unwrap();
    assert_eq!(obj, decoded);
}
194
/// Asserts that `serialised` deserialises to a value equal to `obj`, without
/// checking the serialisation direction.
#[cfg(test)]
#[track_caller]
fn assert_deserialise<'a, T>(serialised: &'a str, obj: T)
where
    T: serde::Deserialize<'a> + std::fmt::Debug + std::cmp::PartialEq,
{
    let decoded: T = serde_json::from_str(serialised).unwrap();
    assert_eq!(obj, decoded);
}
203
/// Asserts that deserialising `serialised` as `T` fails.
#[cfg(test)]
#[track_caller]
fn assert_de_error<'a, T>(serialised: &'a str)
where
    T: serde::Deserialize<'a> + std::fmt::Debug,
{
    // `unwrap_err` panics, printing the unexpectedly parsed value, on success.
    let _ = serde_json::from_str::<T>(serialised).unwrap_err();
}