tycho_util/
serde_helpers.rs

1use std::borrow::Cow;
2use std::marker::PhantomData;
3use std::path::Path;
4use std::str::FromStr;
5
6use anyhow::Result;
7use base64::prelude::{BASE64_STANDARD, Engine as _};
8use bytes::Bytes;
9use serde::de::{Error, Expected, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12pub fn load_json_from_file<T, P>(path: P) -> Result<T>
13where
14    for<'de> T: Deserialize<'de>,
15    P: AsRef<Path>,
16{
17    let data = std::fs::read_to_string(path)?;
18    let de = &mut serde_json::Deserializer::from_str(&data);
19    serde_path_to_error::deserialize(de).map_err(Into::into)
20}
21
22pub fn save_json_to_file<T, P>(value: T, path: P) -> Result<()>
23where
24    T: Serialize,
25    P: AsRef<Path>,
26{
27    let data = serde_json::to_string_pretty(&value)?;
28    std::fs::write(path, data)?;
29    Ok(())
30}
31
32pub mod socket_addr {
33    use std::net::SocketAddr;
34
35    use super::*;
36
37    pub fn serialize<S: Serializer>(value: &SocketAddr, serializer: S) -> Result<S::Ok, S::Error> {
38        if serializer.is_human_readable() {
39            serializer.collect_str(value)
40        } else {
41            value.serialize(serializer)
42        }
43    }
44
45    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<SocketAddr, D::Error> {
46        if deserializer.is_human_readable() {
47            deserializer.deserialize_str(StrVisitor::new())
48        } else {
49            SocketAddr::deserialize(deserializer)
50        }
51    }
52}
53
54pub mod humantime {
55    use std::time::{Duration, SystemTime};
56
57    use super::*;
58
59    pub fn serialize<T, S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
60    where
61        for<'a> Serde<&'a T>: Serialize,
62    {
63        Serde::from(value).serialize(serializer)
64    }
65
66    pub fn deserialize<'a, T, D: Deserializer<'a>>(deserializer: D) -> Result<T, D::Error>
67    where
68        Serde<T>: Deserialize<'a>,
69    {
70        Serde::deserialize(deserializer).map(Serde::into_inner)
71    }
72
73    pub struct Serde<T>(T);
74
75    impl<T> Serde<T> {
76        #[inline]
77        pub fn into_inner(self) -> T {
78            self.0
79        }
80    }
81
82    impl<T> From<T> for Serde<T> {
83        fn from(value: T) -> Serde<T> {
84            Serde(value)
85        }
86    }
87
88    impl<'de> Deserialize<'de> for Serde<Duration> {
89        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Duration>, D::Error> {
90            struct V;
91
92            impl Visitor<'_> for V {
93                type Value = Duration;
94
95                fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96                    f.write_str("a duration")
97                }
98
99                fn visit_str<E: Error>(self, v: &str) -> Result<Duration, E> {
100                    ::humantime::parse_duration(v)
101                        .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
102                }
103            }
104
105            d.deserialize_str(V).map(Serde)
106        }
107    }
108
109    impl<'de> Deserialize<'de> for Serde<SystemTime> {
110        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<SystemTime>, D::Error> {
111            struct V;
112
113            impl Visitor<'_> for V {
114                type Value = SystemTime;
115
116                fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117                    f.write_str("a timestamp")
118                }
119
120                fn visit_str<E: Error>(self, v: &str) -> Result<SystemTime, E> {
121                    ::humantime::parse_rfc3339_weak(v)
122                        .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
123                }
124            }
125
126            d.deserialize_str(V).map(Serde)
127        }
128    }
129
130    impl<'de> Deserialize<'de> for Serde<Option<Duration>> {
131        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<Duration>>, D::Error> {
132            match Option::<Serde<Duration>>::deserialize(d)? {
133                Some(Serde(v)) => Ok(Serde(Some(v))),
134                None => Ok(Serde(None)),
135            }
136        }
137    }
138
139    impl<'de> Deserialize<'de> for Serde<Option<SystemTime>> {
140        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<SystemTime>>, D::Error> {
141            match Option::<Serde<SystemTime>>::deserialize(d)? {
142                Some(Serde(v)) => Ok(Serde(Some(v))),
143                None => Ok(Serde(None)),
144            }
145        }
146    }
147
148    impl Serialize for Serde<&Duration> {
149        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
150            serializer.collect_str(&::humantime::format_duration(*self.0))
151        }
152    }
153
154    impl Serialize for Serde<Duration> {
155        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
156            serializer.collect_str(&::humantime::format_duration(self.0))
157        }
158    }
159
160    impl Serialize for Serde<&SystemTime> {
161        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
162            serializer.collect_str(&::humantime::format_rfc3339(*self.0))
163        }
164    }
165
166    impl Serialize for Serde<SystemTime> {
167        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
168            ::humantime::format_rfc3339(self.0)
169                .to_string()
170                .serialize(serializer)
171        }
172    }
173
174    impl Serialize for Serde<&Option<Duration>> {
175        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
176            match *self.0 {
177                Some(v) => serializer.serialize_some(&Serde(v)),
178                None => serializer.serialize_none(),
179            }
180        }
181    }
182
183    impl Serialize for Serde<Option<Duration>> {
184        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
185            Serde(&self.0).serialize(serializer)
186        }
187    }
188
189    impl Serialize for Serde<&Option<SystemTime>> {
190        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
191            match *self.0 {
192                Some(v) => serializer.serialize_some(&Serde(v)),
193                None => serializer.serialize_none(),
194            }
195        }
196    }
197
198    impl Serialize for Serde<Option<SystemTime>> {
199        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
200            Serde(&self.0).serialize(serializer)
201        }
202    }
203}
204
205pub struct Base64BytesWithLimit<const LIMIT: usize>;
206
207impl<const LIMIT: usize> Base64BytesWithLimit<LIMIT> {
208    pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
209    where
210        S: serde::Serializer,
211    {
212        if serializer.is_human_readable() {
213            let base64 = BASE64_STANDARD.encode(value);
214            serializer.serialize_str(&base64)
215        } else {
216            serializer.serialize_bytes(value)
217        }
218    }
219
220    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
221    where
222        D: serde::Deserializer<'de>,
223    {
224        struct BytesVisitorWithLimit<const LIMIT: usize>;
225
226        impl<'de, const LIMIT: usize> Visitor<'de> for BytesVisitorWithLimit<LIMIT> {
227            type Value = Bytes;
228
229            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230                formatter.write_str("byte array")
231            }
232
233            #[inline]
234            fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
235            where
236                V: serde::de::SeqAccess<'de>,
237            {
238                'valid: {
239                    let hint = seq.size_hint().unwrap_or(0);
240                    if hint > LIMIT {
241                        break 'valid;
242                    }
243
244                    let len = std::cmp::min(hint, 4096);
245                    let mut values: Vec<u8> = Vec::with_capacity(len);
246
247                    while let Some(value) = seq.next_element()? {
248                        if values.len() > LIMIT {
249                            break 'valid;
250                        }
251
252                        values.push(value);
253                    }
254
255                    return Ok(Bytes::from(values));
256                }
257
258                Err(Error::custom("slice is too big"))
259            }
260
261            #[inline]
262            fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
263                if v.len() > LIMIT {
264                    return Err(Error::custom("slice is too big"));
265                }
266                Ok(Bytes::copy_from_slice(v))
267            }
268
269            #[inline]
270            fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
271                if v.len() > LIMIT {
272                    return Err(Error::custom("slice is too big"));
273                }
274                Ok(Bytes::from(v))
275            }
276        }
277
278        if deserializer.is_human_readable() {
279            let BorrowedStr(s) = <_>::deserialize(deserializer)?;
280            if base64::decoded_len_estimate(s.len()) >= LIMIT {
281                return Err(Error::custom("slice is too big"));
282            }
283
284            let v = BASE64_STANDARD
285                .decode(s.as_ref())
286                .map_err(|_e| D::Error::custom("invalid base64"))?;
287
288            Ok(Bytes::from(v))
289        } else {
290            deserializer.deserialize_bytes(BytesVisitorWithLimit::<LIMIT>)
291        }
292    }
293}
294
295pub mod string {
296    use super::*;
297
298    pub fn serialize<S>(value: &dyn std::fmt::Display, serializer: S) -> Result<S::Ok, S::Error>
299    where
300        S: serde::Serializer,
301    {
302        serializer.collect_str(value)
303    }
304
305    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
306    where
307        D: serde::Deserializer<'de>,
308        T: FromStr,
309        T::Err: std::fmt::Display,
310    {
311        BorrowedStr::deserialize(deserializer)
312            .and_then(|data| T::from_str(&data.0).map_err(D::Error::custom))
313    }
314}
315
316pub mod option_string {
317    use super::*;
318
319    pub fn serialize<S, T>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
320    where
321        S: serde::Serializer,
322        T: std::fmt::Display,
323    {
324        #[derive(Serialize)]
325        #[serde(transparent)]
326        #[repr(transparent)]
327        struct Helper<'a, T: std::fmt::Display>(#[serde(with = "string")] &'a T);
328
329        value.as_ref().map(Helper).serialize(serializer)
330    }
331
332    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
333    where
334        D: serde::Deserializer<'de>,
335        T: FromStr,
336        T::Err: std::fmt::Display,
337    {
338        #[derive(Deserialize)]
339        #[serde(transparent)]
340        #[repr(transparent)]
341        struct Helper<T>(#[serde(with = "string")] T)
342        where
343            T: FromStr,
344            T::Err: std::fmt::Display;
345
346        Option::<Helper<T>>::deserialize(deserializer).map(|x| x.map(|Helper(x)| x))
347    }
348}
349
350pub mod signature {
351    use base64::engine::Engine as _;
352    use base64::prelude::BASE64_STANDARD;
353
354    use super::*;
355
356    pub fn serialize<S>(data: &[u8; 64], serializer: S) -> Result<S::Ok, S::Error>
357    where
358        S: serde::Serializer,
359    {
360        if serializer.is_human_readable() {
361            serializer.serialize_str(&BASE64_STANDARD.encode(data))
362        } else {
363            data.serialize(serializer)
364        }
365    }
366
367    pub fn deserialize<'de, D>(deserializer: D) -> Result<Box<[u8; 64]>, D::Error>
368    where
369        D: serde::Deserializer<'de>,
370    {
371        use serde::de::Error;
372
373        if deserializer.is_human_readable() {
374            <BorrowedStr<'_> as Deserialize>::deserialize(deserializer).and_then(
375                |BorrowedStr(s)| {
376                    let mut buffer = [0u8; 66];
377                    match BASE64_STANDARD.decode_slice(s.as_ref(), &mut buffer) {
378                        Ok(64) => {
379                            let [data @ .., _, _] = buffer;
380                            Ok(Box::new(data))
381                        }
382                        _ => Err(Error::custom("Invalid signature")),
383                    }
384                },
385            )
386        } else {
387            deserializer
388                .deserialize_bytes(BytesVisitor::<64>)
389                .map(Box::new)
390        }
391    }
392}
393
/// Newtype over `Cow<'a, str>` used as a deserialization target for strings.
///
/// The `#[serde(borrow)]` attribute lets the derived `Deserialize` impl borrow
/// the string from the input instead of always producing an owned copy.
#[derive(Deserialize)]
#[repr(transparent)]
pub struct BorrowedStr<'a>(#[serde(borrow)] pub Cow<'a, str>);
397
398pub struct StrVisitor<S>(PhantomData<S>);
399
400impl<S> StrVisitor<S> {
401    pub const fn new() -> Self {
402        Self(PhantomData)
403    }
404}
405
406impl<S> Default for StrVisitor<S> {
407    fn default() -> Self {
408        Self::new()
409    }
410}
411
412impl<S: FromStr> Visitor<'_> for StrVisitor<S>
413where
414    <S as FromStr>::Err: std::fmt::Display,
415{
416    type Value = S;
417
418    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419        write!(f, "a string")
420    }
421
422    fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
423        value.parse::<Self::Value>().map_err(Error::custom)
424    }
425}
426
427pub struct BytesVisitor<const M: usize>;
428
429impl<'de, const M: usize> Visitor<'de> for BytesVisitor<M> {
430    type Value = [u8; M];
431
432    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        f.write_fmt(format_args!("a byte array of size {M}"))
434    }
435
436    fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
437        v.try_into()
438            .map_err(|_e| Error::invalid_length(v.len(), &self))
439    }
440
441    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
442    where
443        A: serde::de::SeqAccess<'de>,
444    {
445        struct SeqIter<'de, A, T> {
446            access: A,
447            marker: PhantomData<(&'de (), T)>,
448        }
449
450        impl<'de, A, T> SeqIter<'de, A, T> {
451            pub(crate) fn new(access: A) -> Self
452            where
453                A: serde::de::SeqAccess<'de>,
454            {
455                Self {
456                    access,
457                    marker: PhantomData,
458                }
459            }
460        }
461
462        impl<'de, A, T> Iterator for SeqIter<'de, A, T>
463        where
464            A: serde::de::SeqAccess<'de>,
465            T: Deserialize<'de>,
466        {
467            type Item = Result<T, A::Error>;
468
469            fn next(&mut self) -> Option<Self::Item> {
470                self.access.next_element().transpose()
471            }
472
473            fn size_hint(&self) -> (usize, Option<usize>) {
474                match self.access.size_hint() {
475                    Some(size) => (size, Some(size)),
476                    None => (0, None),
477                }
478            }
479        }
480
481        fn array_from_iterator<I, T, E, const N: usize>(
482            mut iter: I,
483            expected: &dyn Expected,
484        ) -> Result<[T; N], E>
485        where
486            I: Iterator<Item = Result<T, E>>,
487            E: Error,
488        {
489            use core::mem::MaybeUninit;
490
491            /// # Safety
492            /// The following must be true:
493            /// - The first `num` elements must be initialized.
494            unsafe fn drop_array_elems<T, const N: usize>(
495                num: usize,
496                mut arr: [MaybeUninit<T>; N],
497            ) {
498                arr[..num]
499                    .iter_mut()
500                    .for_each(|item| unsafe { item.assume_init_drop() });
501            }
502
503            // SAFETY: It is safe to assume that array of uninitialized values is initialized itself.
504            let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };
505
506            // NOTE: Leaks memory on panic
507            for (i, elem) in arr[..].iter_mut().enumerate() {
508                *elem = match iter.next() {
509                    Some(Ok(value)) => MaybeUninit::new(value),
510                    Some(Err(err)) => {
511                        // SAFETY: Items until `i` were initialized.
512                        unsafe { drop_array_elems(i, arr) };
513                        return Err(err);
514                    }
515                    None => {
516                        // SAFETY: Items until `i` were initialized.
517                        unsafe { drop_array_elems(i, arr) };
518                        return Err(Error::invalid_length(i, expected));
519                    }
520                };
521            }
522
523            // Everything is initialized. Transmute the array to the initialized type.
524            // A normal transmute is not possible because of:
525            // https://github.com/rust-lang/rust/issues/61956
526            Ok(unsafe { std::mem::transmute_copy(&arr) })
527        }
528
529        array_from_iterator(SeqIter::new(seq), &self)
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a struct using `option_string` through JSON for `None`, a
    /// small value and `u64::MAX`.
    #[test]
    fn struct_with_option_string() {
        #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "option_string")]
            value: Option<u64>,
        }

        let cases = [None, Some(123), Some(u64::MAX)].map(|value| Test { value });

        for original in cases {
            let json = serde_json::to_string(&original).unwrap();
            println!("{json}");
            let roundtripped: Test = serde_json::from_str(&json).unwrap();
            assert_eq!(original, roundtripped);
        }
    }
}