1use core::fmt;
2use serde::{
3 de::{self, Unexpected, Visitor},
4 Deserializer, Serializer,
5};
6
7pub struct LitStr<'a>(pub &'a str);
10
11impl<'a, 'de> Visitor<'de> for LitStr<'a> {
12 type Value = ();
13
14 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
15 write!(formatter, "the lit {}", self.0)
16 }
17
18 fn visit_str<E>(self, s: &str) -> Result<(), E>
19 where
20 E: de::Error,
21 {
22 if s == self.0 {
23 Ok(())
24 } else {
25 Err(de::Error::invalid_value(Unexpected::Str(s), &self))
26 }
27 }
28}
29
30pub struct LitFloat(pub f64);
31
32impl<'de> Visitor<'de> for LitFloat {
33 type Value = ();
34
35 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
36 write!(formatter, "the lit {}", self.0)
37 }
38
39 fn visit_f64<E>(self, v: f64) -> Result<(), E>
40 where
41 E: de::Error,
42 {
43 if v == self.0 {
44 Ok(())
45 } else {
46 Err(de::Error::invalid_value(Unexpected::Float(v), &self))
47 }
48 }
49}
50
51pub struct LitInt<const N: i64>;
52
53impl<const N: i64> LitInt<N> {
54 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
55 deserializer.deserialize_any(Self)
56 }
57
58 pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
59 serializer.serialize_i64(N)
60 }
61}
62
63impl<'de, const N: i64> Visitor<'de> for LitInt<N> {
64 type Value = ();
65
66 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
67 write!(formatter, "the lit {}", N)
68 }
69
70 fn visit_i64<E>(self, v: i64) -> Result<(), E>
71 where
72 E: de::Error,
73 {
74 if v == N {
75 Ok(())
76 } else {
77 Err(de::Error::invalid_value(Unexpected::Signed(v), &self))
78 }
79 }
80
81 fn visit_u64<E>(self, v: u64) -> Result<(), E>
82 where
83 E: de::Error,
84 {
85 if v as i64 == N {
86 Ok(())
87 } else {
88 Err(de::Error::invalid_value(Unexpected::Unsigned(v), &self))
89 }
90 }
91}
92
93pub struct LitBool<const B: bool>;
94
95impl<const B: bool> LitBool<B> {
96 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
97 deserializer.deserialize_bool(Self)
98 }
99
100 pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
101 serializer.serialize_bool(B)
102 }
103}
104
105impl<'de, const B: bool> Visitor<'de> for LitBool<B> {
106 type Value = ();
107
108 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
109 write!(formatter, "the lit {}", B)
110 }
111
112 fn visit_bool<E>(self, v: bool) -> Result<(), E>
113 where
114 E: de::Error,
115 {
116 if v == B {
117 Ok(())
118 } else {
119 Err(de::Error::invalid_value(Unexpected::Bool(v), &self))
120 }
121 }
122}
123
124pub struct LitChar<const C: char>;
125
126impl<const C: char> LitChar<C> {
127 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<(), D::Error> {
128 deserializer.deserialize_char(Self)
129 }
130
131 pub fn serialize<S: Serializer>(serializer: S) -> Result<S::Ok, S::Error> {
132 serializer.serialize_char(C)
133 }
134}
135
136impl<'de, const C: char> Visitor<'de> for LitChar<C> {
137 type Value = ();
138
139 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
140 write!(formatter, "the lit {}", C)
141 }
142
143 fn visit_str<E>(self, v: &str) -> Result<(), E>
144 where
145 E: de::Error,
146 {
147 if v.starts_with(C) {
148 Ok(())
149 } else {
150 Err(de::Error::invalid_value(Unexpected::Str(v), &self))
151 }
152 }
153}
154
/// Declares a public unit struct that stands for one string literal.
///
/// `lit_str!(Name, "value")` generates `pub struct Name;` with `deserialize`
/// and `serialize` associated functions. Deserialization goes through
/// `deserialize_str` with a `LitStr` visitor holding the literal;
/// serialization writes the literal with `serialize_str`.
#[macro_export]
macro_rules! lit_str {
    ($name:ident, $lit:expr) => {
        pub struct $name;

        impl $name {
            pub fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                let visitor = $crate::LitStr($lit);
                deserializer.deserialize_str(visitor)
            }

            pub fn serialize<S>(serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str($lit)
            }
        }
    };
}
173
/// Declares a public unit struct that stands for one floating-point literal.
///
/// `lit_float!(Name, 3.1)` generates `pub struct Name;` with `deserialize`
/// and `serialize` associated functions. The literal is cast to `f64`;
/// deserialization goes through `deserialize_f64` with a `LitFloat` visitor
/// and serialization writes the value with `serialize_f64`.
#[macro_export]
macro_rules! lit_float {
    ($name:ident, $lit:expr) => {
        pub struct $name;

        impl $name {
            pub fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                let visitor = $crate::LitFloat($lit as f64);
                deserializer.deserialize_f64(visitor)
            }

            pub fn serialize<S>(serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_f64($lit as f64)
            }
        }
    };
}
192
#[cfg(test)]
mod test {
    use super::*;
    use serde::{Deserialize, Serialize};

    lit_str!(LitAuto, "auto");
    lit_str!(LitBlah, "blah");
    lit_float!(Lit3_1, 3.1);

    // Untagged enum mixing literal-only variants with a general `f64` variant.
    // Per serde's documentation, untagged variants are tried in declaration
    // order, so the literal numbers are listed before `Number`.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(untagged)]
    enum Items {
        #[serde(with = "LitAuto")]
        Auto,
        #[serde(with = "LitBlah")]
        Blah,
        #[serde(with = "LitInt::<123>")]
        Num123,
        #[serde(with = "Lit3_1")]
        Num3Dot1,
        Number(f64),
        #[serde(with = "LitBool::<true>")]
        True,
        #[serde(with = "LitBool::<false>")]
        False,
        #[serde(with = "LitChar::<'z'>")]
        SingleChar,
    }

    /// Checks that every variant serializes to its literal and parses back.
    #[test]
    fn test_serde() {
        // Value -> expected JSON text.
        let serialize_cases = [
            (Items::Number(4.5), "4.5"),
            (Items::Auto, "\"auto\""),
            (Items::Blah, "\"blah\""),
            (Items::Num123, "123"),
            (Items::Num3Dot1, "3.1"),
            (Items::True, "true"),
            (Items::False, "false"),
            (Items::SingleChar, "\"z\""),
        ];
        for (item, json) in &serialize_cases {
            assert_eq!(serde_json::to_string_pretty(item).unwrap(), *json);
        }

        // JSON text -> expected value.
        let deserialize_cases = [
            ("2.3", Items::Number(2.3)),
            ("\"auto\"", Items::Auto),
            ("\"blah\"", Items::Blah),
            ("123", Items::Num123),
            ("3.1", Items::Num3Dot1),
            ("true", Items::True),
            ("false", Items::False),
            ("\"z\"", Items::SingleChar),
        ];
        for (json, item) in &deserialize_cases {
            assert_eq!(&serde_json::from_str::<Items>(json).unwrap(), item);
        }
    }
}