1use std::borrow::Cow;
2use std::marker::PhantomData;
3use std::path::Path;
4use std::str::FromStr;
5
6use anyhow::Result;
7use base64::prelude::{BASE64_STANDARD, Engine as _};
8use bytes::Bytes;
9use serde::de::{Error, Expected, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12pub fn load_json_from_file<T, P>(path: P) -> Result<T>
13where
14 for<'de> T: Deserialize<'de>,
15 P: AsRef<Path>,
16{
17 let data = std::fs::read_to_string(path)?;
18 let de = &mut serde_json::Deserializer::from_str(&data);
19 serde_path_to_error::deserialize(de).map_err(Into::into)
20}
21
22pub fn save_json_to_file<T, P>(value: T, path: P) -> Result<()>
23where
24 T: Serialize,
25 P: AsRef<Path>,
26{
27 let data = serde_json::to_string_pretty(&value)?;
28 std::fs::write(path, data)?;
29 Ok(())
30}
31
32pub mod socket_addr {
33 use std::net::SocketAddr;
34
35 use super::*;
36
37 pub fn serialize<S: Serializer>(value: &SocketAddr, serializer: S) -> Result<S::Ok, S::Error> {
38 if serializer.is_human_readable() {
39 serializer.collect_str(value)
40 } else {
41 value.serialize(serializer)
42 }
43 }
44
45 pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<SocketAddr, D::Error> {
46 if deserializer.is_human_readable() {
47 deserializer.deserialize_str(StrVisitor::new())
48 } else {
49 SocketAddr::deserialize(deserializer)
50 }
51 }
52}
53
54pub mod humantime {
55 use std::time::{Duration, SystemTime};
56
57 use super::*;
58
59 pub fn serialize<T, S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 for<'a> Serde<&'a T>: Serialize,
62 {
63 Serde::from(value).serialize(serializer)
64 }
65
66 pub fn deserialize<'a, T, D: Deserializer<'a>>(deserializer: D) -> Result<T, D::Error>
67 where
68 Serde<T>: Deserialize<'a>,
69 {
70 Serde::deserialize(deserializer).map(Serde::into_inner)
71 }
72
73 pub struct Serde<T>(T);
74
75 impl<T> Serde<T> {
76 #[inline]
77 pub fn into_inner(self) -> T {
78 self.0
79 }
80 }
81
82 impl<T> From<T> for Serde<T> {
83 fn from(value: T) -> Serde<T> {
84 Serde(value)
85 }
86 }
87
88 impl<'de> Deserialize<'de> for Serde<Duration> {
89 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Duration>, D::Error> {
90 struct V;
91
92 impl Visitor<'_> for V {
93 type Value = Duration;
94
95 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.write_str("a duration")
97 }
98
99 fn visit_str<E: Error>(self, v: &str) -> Result<Duration, E> {
100 ::humantime::parse_duration(v)
101 .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
102 }
103 }
104
105 d.deserialize_str(V).map(Serde)
106 }
107 }
108
109 impl<'de> Deserialize<'de> for Serde<SystemTime> {
110 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<SystemTime>, D::Error> {
111 struct V;
112
113 impl Visitor<'_> for V {
114 type Value = SystemTime;
115
116 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.write_str("a timestamp")
118 }
119
120 fn visit_str<E: Error>(self, v: &str) -> Result<SystemTime, E> {
121 ::humantime::parse_rfc3339_weak(v)
122 .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
123 }
124 }
125
126 d.deserialize_str(V).map(Serde)
127 }
128 }
129
130 impl<'de> Deserialize<'de> for Serde<Option<Duration>> {
131 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<Duration>>, D::Error> {
132 match Option::<Serde<Duration>>::deserialize(d)? {
133 Some(Serde(v)) => Ok(Serde(Some(v))),
134 None => Ok(Serde(None)),
135 }
136 }
137 }
138
139 impl<'de> Deserialize<'de> for Serde<Option<SystemTime>> {
140 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<SystemTime>>, D::Error> {
141 match Option::<Serde<SystemTime>>::deserialize(d)? {
142 Some(Serde(v)) => Ok(Serde(Some(v))),
143 None => Ok(Serde(None)),
144 }
145 }
146 }
147
148 impl Serialize for Serde<&Duration> {
149 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
150 serializer.collect_str(&::humantime::format_duration(*self.0))
151 }
152 }
153
154 impl Serialize for Serde<Duration> {
155 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
156 serializer.collect_str(&::humantime::format_duration(self.0))
157 }
158 }
159
160 impl Serialize for Serde<&SystemTime> {
161 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
162 serializer.collect_str(&::humantime::format_rfc3339(*self.0))
163 }
164 }
165
166 impl Serialize for Serde<SystemTime> {
167 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
168 ::humantime::format_rfc3339(self.0)
169 .to_string()
170 .serialize(serializer)
171 }
172 }
173
174 impl Serialize for Serde<&Option<Duration>> {
175 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
176 match *self.0 {
177 Some(v) => serializer.serialize_some(&Serde(v)),
178 None => serializer.serialize_none(),
179 }
180 }
181 }
182
183 impl Serialize for Serde<Option<Duration>> {
184 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
185 Serde(&self.0).serialize(serializer)
186 }
187 }
188
189 impl Serialize for Serde<&Option<SystemTime>> {
190 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
191 match *self.0 {
192 Some(v) => serializer.serialize_some(&Serde(v)),
193 None => serializer.serialize_none(),
194 }
195 }
196 }
197
198 impl Serialize for Serde<Option<SystemTime>> {
199 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
200 Serde(&self.0).serialize(serializer)
201 }
202 }
203}
204
205pub struct Base64BytesWithLimit<const LIMIT: usize>;
206
207impl<const LIMIT: usize> Base64BytesWithLimit<LIMIT> {
208 pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
209 where
210 S: serde::Serializer,
211 {
212 if serializer.is_human_readable() {
213 let base64 = BASE64_STANDARD.encode(value);
214 serializer.serialize_str(&base64)
215 } else {
216 serializer.serialize_bytes(value)
217 }
218 }
219
220 pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
221 where
222 D: serde::Deserializer<'de>,
223 {
224 struct BytesVisitorWithLimit<const LIMIT: usize>;
225
226 impl<'de, const LIMIT: usize> Visitor<'de> for BytesVisitorWithLimit<LIMIT> {
227 type Value = Bytes;
228
229 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 formatter.write_str("byte array")
231 }
232
233 #[inline]
234 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
235 where
236 V: serde::de::SeqAccess<'de>,
237 {
238 'valid: {
239 let hint = seq.size_hint().unwrap_or(0);
240 if hint > LIMIT {
241 break 'valid;
242 }
243
244 let len = std::cmp::min(hint, 4096);
245 let mut values: Vec<u8> = Vec::with_capacity(len);
246
247 while let Some(value) = seq.next_element()? {
248 if values.len() > LIMIT {
249 break 'valid;
250 }
251
252 values.push(value);
253 }
254
255 return Ok(Bytes::from(values));
256 }
257
258 Err(Error::custom("slice is too big"))
259 }
260
261 #[inline]
262 fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
263 if v.len() > LIMIT {
264 return Err(Error::custom("slice is too big"));
265 }
266 Ok(Bytes::copy_from_slice(v))
267 }
268
269 #[inline]
270 fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
271 if v.len() > LIMIT {
272 return Err(Error::custom("slice is too big"));
273 }
274 Ok(Bytes::from(v))
275 }
276 }
277
278 if deserializer.is_human_readable() {
279 let BorrowedStr(s) = <_>::deserialize(deserializer)?;
280 if base64::decoded_len_estimate(s.len()) >= LIMIT {
281 return Err(Error::custom("slice is too big"));
282 }
283
284 let v = BASE64_STANDARD
285 .decode(s.as_ref())
286 .map_err(|_e| D::Error::custom("invalid base64"))?;
287
288 Ok(Bytes::from(v))
289 } else {
290 deserializer.deserialize_bytes(BytesVisitorWithLimit::<LIMIT>)
291 }
292 }
293}
294
295pub mod string {
296 use super::*;
297
298 pub fn serialize<S>(value: &dyn std::fmt::Display, serializer: S) -> Result<S::Ok, S::Error>
299 where
300 S: serde::Serializer,
301 {
302 serializer.collect_str(value)
303 }
304
305 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
306 where
307 D: serde::Deserializer<'de>,
308 T: FromStr,
309 T::Err: std::fmt::Display,
310 {
311 BorrowedStr::deserialize(deserializer)
312 .and_then(|data| T::from_str(&data.0).map_err(D::Error::custom))
313 }
314}
315
316pub mod option_string {
317 use super::*;
318
319 pub fn serialize<S, T>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
320 where
321 S: serde::Serializer,
322 T: std::fmt::Display,
323 {
324 #[derive(Serialize)]
325 #[serde(transparent)]
326 #[repr(transparent)]
327 struct Helper<'a, T: std::fmt::Display>(#[serde(with = "string")] &'a T);
328
329 value.as_ref().map(Helper).serialize(serializer)
330 }
331
332 pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
333 where
334 D: serde::Deserializer<'de>,
335 T: FromStr,
336 T::Err: std::fmt::Display,
337 {
338 #[derive(Deserialize)]
339 #[serde(transparent)]
340 #[repr(transparent)]
341 struct Helper<T>(#[serde(with = "string")] T)
342 where
343 T: FromStr,
344 T::Err: std::fmt::Display;
345
346 Option::<Helper<T>>::deserialize(deserializer).map(|x| x.map(|Helper(x)| x))
347 }
348}
349
350pub mod signature {
351 use base64::engine::Engine as _;
352 use base64::prelude::BASE64_STANDARD;
353
354 use super::*;
355
356 pub fn serialize<S>(data: &[u8; 64], serializer: S) -> Result<S::Ok, S::Error>
357 where
358 S: serde::Serializer,
359 {
360 if serializer.is_human_readable() {
361 serializer.serialize_str(&BASE64_STANDARD.encode(data))
362 } else {
363 data.serialize(serializer)
364 }
365 }
366
367 pub fn deserialize<'de, D>(deserializer: D) -> Result<Box<[u8; 64]>, D::Error>
368 where
369 D: serde::Deserializer<'de>,
370 {
371 use serde::de::Error;
372
373 if deserializer.is_human_readable() {
374 <BorrowedStr<'_> as Deserialize>::deserialize(deserializer).and_then(
375 |BorrowedStr(s)| {
376 let mut buffer = [0u8; 66];
377 match BASE64_STANDARD.decode_slice(s.as_ref(), &mut buffer) {
378 Ok(64) => {
379 let [data @ .., _, _] = buffer;
380 Ok(Box::new(data))
381 }
382 _ => Err(Error::custom("Invalid signature")),
383 }
384 },
385 )
386 } else {
387 deserializer
388 .deserialize_bytes(BytesVisitor::<64>)
389 .map(Box::new)
390 }
391 }
392}
393
/// Deserialization helper wrapping a `Cow<'a, str>`.
///
/// The field is marked `#[serde(borrow)]`, which lets the derived
/// `Deserialize` impl borrow the string from the input for lifetime `'a`
/// instead of always producing an owned `String`.
#[derive(Deserialize)]
#[repr(transparent)]
pub struct BorrowedStr<'a>(#[serde(borrow)] pub Cow<'a, str>);
397
398pub struct StrVisitor<S>(PhantomData<S>);
399
400impl<S> StrVisitor<S> {
401 pub const fn new() -> Self {
402 Self(PhantomData)
403 }
404}
405
406impl<S> Default for StrVisitor<S> {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
412impl<S: FromStr> Visitor<'_> for StrVisitor<S>
413where
414 <S as FromStr>::Err: std::fmt::Display,
415{
416 type Value = S;
417
418 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 write!(f, "a string")
420 }
421
422 fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
423 value.parse::<Self::Value>().map_err(Error::custom)
424 }
425}
426
/// Serde visitor producing a fixed-size byte array `[u8; M]`, either from a
/// byte slice of exactly `M` bytes or from a sequence of `M` elements.
pub struct BytesVisitor<const M: usize>;

impl<'de, const M: usize> Visitor<'de> for BytesVisitor<M> {
    type Value = [u8; M];

    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("a byte array of size {M}"))
    }

    /// Converts `v` into `[u8; M]`; any other length yields an
    /// `invalid_length` error reporting `v.len()`.
    fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
        v.try_into()
            .map_err(|_e| Error::invalid_length(v.len(), &self))
    }

    /// Reads exactly `M` elements from `seq` into an array.
    ///
    /// Fails with the element's own error if one cannot be deserialized, or
    /// with `invalid_length` if the sequence ends early. Elements beyond the
    /// first `M` are not requested from `seq` here.
    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
    where
        A: serde::de::SeqAccess<'de>,
    {
        // Adapts a `SeqAccess` into an `Iterator` of `Result<T, A::Error>`.
        struct SeqIter<'de, A, T> {
            access: A,
            marker: PhantomData<(&'de (), T)>,
        }

        impl<'de, A, T> SeqIter<'de, A, T> {
            pub(crate) fn new(access: A) -> Self
            where
                A: serde::de::SeqAccess<'de>,
            {
                Self {
                    access,
                    marker: PhantomData,
                }
            }
        }

        impl<'de, A, T> Iterator for SeqIter<'de, A, T>
        where
            A: serde::de::SeqAccess<'de>,
            T: Deserialize<'de>,
        {
            type Item = Result<T, A::Error>;

            fn next(&mut self) -> Option<Self::Item> {
                // `Result<Option<T>, E>` -> `Option<Result<T, E>>`.
                self.access.next_element().transpose()
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                match self.access.size_hint() {
                    Some(size) => (size, Some(size)),
                    None => (0, None),
                }
            }
        }

        // Builds `[T; N]` from the first `N` items of `iter`, stopping at the
        // first error or when the iterator runs out early.
        fn array_from_iterator<I, T, E, const N: usize>(
            mut iter: I,
            expected: &dyn Expected,
        ) -> Result<[T; N], E>
        where
            I: Iterator<Item = Result<T, E>>,
            E: Error,
        {
            use core::mem::MaybeUninit;

            /// Drops the first `num` elements of `arr` in place.
            ///
            /// # Safety
            ///
            /// The first `num` elements of `arr` must be initialized, and
            /// `num` must not exceed `N` (the slice indexing panics otherwise).
            unsafe fn drop_array_elems<T, const N: usize>(
                num: usize,
                mut arr: [MaybeUninit<T>; N],
            ) {
                arr[..num]
                    .iter_mut()
                    // SAFETY: the caller guarantees these elements are initialized.
                    .for_each(|item| unsafe { item.assume_init_drop() });
            }

            // SAFETY: the value being "initialized" is an array of
            // `MaybeUninit<T>`, whose elements need no initialization.
            let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };

            for (i, elem) in arr[..].iter_mut().enumerate() {
                *elem = match iter.next() {
                    Some(Ok(value)) => MaybeUninit::new(value),
                    Some(Err(err)) => {
                        // SAFETY: elements `0..i` were written by earlier
                        // iterations of this loop.
                        unsafe { drop_array_elems(i, arr) };
                        return Err(err);
                    }
                    None => {
                        // SAFETY: elements `0..i` were written by earlier
                        // iterations of this loop.
                        unsafe { drop_array_elems(i, arr) };
                        return Err(Error::invalid_length(i, expected));
                    }
                };
            }

            // SAFETY: the loop finished without returning, so all `N` elements
            // are initialized; `MaybeUninit<T>` has the same size and alignment
            // as `T`, and dropping `arr` afterwards does not drop the elements.
            Ok(unsafe { std::mem::transmute_copy(&arr) })
        }

        array_from_iterator(SeqIter::new(seq), &self)
    }
}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips `Option<u64>` through JSON using the `option_string` adapter.
    #[test]
    fn struct_with_option_string() {
        #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "option_string")]
            value: Option<u64>,
        }

        let cases = [None, Some(123), Some(u64::MAX)];
        for original in cases.map(|value| Test { value }) {
            let encoded = serde_json::to_string(&original).unwrap();
            println!("{encoded}");
            let decoded: Test = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }
}