Skip to main content

serde_literals/
lib.rs

1use core::fmt;
2use serde::{
3    de::{self, Unexpected, Visitor},
4    Deserializer, Serializer,
5};
6
// serde_literals
// Deserialise and serialise literal strings, ints, floats, bools and chars to and from enum unit variants.
9pub struct LitStr<'a>(pub &'a str);
10
11impl<'a, 'de> Visitor<'de> for LitStr<'a> {
12    type Value = ();
13
14    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
15        write!(formatter, "the lit {}", self.0)
16    }
17
18    fn visit_str<E>(self, s: &str) -> Result<(), E>
19    where
20        E: de::Error,
21    {
22        if s == self.0 {
23            Ok(())
24        } else {
25            Err(de::Error::invalid_value(Unexpected::Str(s), &self))
26        }
27    }
28}
29
30pub struct LitFloat(pub f64);
31
32impl<'de> Visitor<'de> for LitFloat {
33    type Value = ();
34
35    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
36        write!(formatter, "the lit {}", self.0)
37    }
38
39    fn visit_f64<E>(self, v: f64) -> Result<(), E>
40    where
41        E: de::Error,
42    {
43        if v == self.0 {
44            Ok(())
45        } else {
46            Err(de::Error::invalid_value(Unexpected::Float(v), &self))
47        }
48    }
49}
50
51pub struct LitInt<const N: i64>;
52
53impl<const N: i64> LitInt<N> {
54    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
55        deserializer.deserialize_any(Self)
56    }
57
58    pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
59        serializer.serialize_i64(N)
60    }
61}
62
63impl<'de, const N: i64> Visitor<'de> for LitInt<N> {
64    type Value = ();
65
66    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
67        write!(formatter, "the lit {}", N)
68    }
69
70    fn visit_i64<E>(self, v: i64) -> Result<(), E>
71    where
72        E: de::Error,
73    {
74        if v == N {
75            Ok(())
76        } else {
77            Err(de::Error::invalid_value(Unexpected::Signed(v), &self))
78        }
79    }
80
81    fn visit_u64<E>(self, v: u64) -> Result<(), E>
82    where
83        E: de::Error,
84    {
85        if v as i64 == N {
86            Ok(())
87        } else {
88            Err(de::Error::invalid_value(Unexpected::Unsigned(v), &self))
89        }
90    }
91}
92
93pub struct LitBool<const B: bool>;
94
95impl<const B: bool> LitBool<B> {
96    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
97        deserializer.deserialize_bool(Self)
98    }
99
100    pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
101        serializer.serialize_bool(B)
102    }
103}
104
105impl<'de, const B: bool> Visitor<'de> for LitBool<B> {
106    type Value = ();
107
108    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
109        write!(formatter, "the lit {}", B)
110    }
111
112    fn visit_bool<E>(self, v: bool) -> Result<(), E>
113    where
114        E: de::Error,
115    {
116        if v == B {
117            Ok(())
118        } else {
119            Err(de::Error::invalid_value(Unexpected::Bool(v), &self))
120        }
121    }
122}
123
124pub struct LitChar<const C: char>;
125
126impl<const C: char> LitChar<C> {
127    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
128        deserializer.deserialize_char(Self)
129    }
130
131    pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
132        serializer.serialize_char(C)
133    }
134}
135
136impl<'de, const C: char> Visitor<'de> for LitChar<C> {
137    type Value = ();
138
139    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
140        write!(formatter, "the lit {}", C)
141    }
142
143    fn visit_str<E>(self, v: &str) -> Result<(), E>
144    where
145        E: de::Error,
146    {
147        if v.starts_with(C) {
148            Ok(())
149        } else {
150            Err(de::Error::invalid_value(Unexpected::Str(v), &self))
151        }
152    }
153}
154
/// Declares a public unit struct `$struct_name` standing for the string literal `$val`.
///
/// The generated type exposes `deserialize` and `serialize` associated functions, so it
/// can be named in `#[serde(with = "...")]`: serialising writes `$val` as a string and
/// deserialising succeeds only for that exact string.
#[macro_export]
macro_rules! lit_str {
    ($struct_name:ident, $val:expr) => {
        pub struct $struct_name;

        impl $struct_name {
            // Matches the input against the literal through `LitStr`.
            pub fn deserialize<'de, D>(source: D) -> ::core::result::Result<(), D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                source.deserialize_str($crate::LitStr($val))
            }

            // Emits the literal itself.
            pub fn serialize<S>(sink: S) -> ::core::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                sink.serialize_str($val)
            }
        }
    };
}
173
/// Declares a public unit struct `$struct_name` standing for the float literal `$val`.
///
/// The generated type exposes `deserialize` and `serialize` associated functions, so it
/// can be named in `#[serde(with = "...")]`: serialising writes `$val` (cast to `f64`)
/// and deserialising succeeds only for a float equal to it.
#[macro_export]
macro_rules! lit_float {
    ($struct_name:ident, $val:expr) => {
        pub struct $struct_name;

        impl $struct_name {
            // Matches the input against the literal through `LitFloat`.
            pub fn deserialize<'de, D>(source: D) -> ::core::result::Result<(), D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                source.deserialize_f64($crate::LitFloat($val as f64))
            }

            // Emits the literal itself.
            pub fn serialize<S>(sink: S) -> ::core::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                sink.serialize_f64($val as f64)
            }
        }
    };
}
192
#[cfg(test)]
mod test {
    use super::*;
    use serde::{Deserialize, Serialize};

    // Literal marker types generated by the macros under test.
    lit_str!(LitAuto, "auto");
    lit_str!(LitBlah, "blah");
    lit_float!(Lit3_1, 3.1);

    /// Untagged enum mixing literal unit variants with a general `f64` variant.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(untagged)]
    enum Items {
        #[serde(with = "LitAuto")]
        Auto,
        #[serde(with = "LitBlah")]
        Blah,
        #[serde(with = "LitInt::<123>")]
        Num123,
        #[serde(with = "Lit3_1")]
        Num3Dot1,
        Number(f64),
        #[serde(with = "LitBool::<true>")]
        True,
        #[serde(with = "LitBool::<false>")]
        False,
        #[serde(with = "LitChar::<'z'>")]
        SingleChar,
    }

    /// Checks the JSON form of every variant in both directions.
    #[test]
    fn test_serde() {
        let encode = |item: &Items| serde_json::to_string_pretty(item).unwrap();
        let decode = |json: &str| serde_json::from_str::<Items>(json).unwrap();

        // Each variant must serialise to its bare literal.
        let encode_cases = [
            (Items::Number(4.5), "4.5"),
            (Items::Auto, "\"auto\""),
            (Items::Blah, "\"blah\""),
            (Items::Num123, "123"),
            (Items::Num3Dot1, "3.1"),
            (Items::True, "true"),
            (Items::False, "false"),
            (Items::SingleChar, "\"z\""),
        ];
        for (item, json) in encode_cases.iter() {
            assert_eq!(encode(item), *json);
        }

        // Each literal must deserialise back to its dedicated variant.
        let decode_cases = [
            ("2.3", Items::Number(2.3)),
            ("\"auto\"", Items::Auto),
            ("\"blah\"", Items::Blah),
            ("123", Items::Num123),
            ("3.1", Items::Num3Dot1),
            ("true", Items::True),
            ("false", Items::False),
            ("\"z\"", Items::SingleChar),
        ];
        for (json, item) in decode_cases.iter() {
            assert_eq!(&decode(*json), item);
        }
    }
}