scratchstack_aspen/
serutil.rs

1use {
2    log::{debug, error},
3    serde::{
4        de::{
5            self,
6            value::{MapAccessDeserializer, SeqAccessDeserializer},
7            Deserializer, MapAccess, SeqAccess, Unexpected, Visitor,
8        },
9        ser::{SerializeSeq, Serializer},
10        Deserialize, Serialize,
11    },
12    std::{
13        any::type_name,
14        fmt::{Debug, Display, Error as FmtError, Formatter, Result as FmtResult},
15        marker::PhantomData,
16        str::{from_utf8, FromStr},
17    },
18};
19
/// Return a short, human-readable name for `E`, used in deserialization error messages.
///
/// The full `std::any::type_name` is reduced in three steps:
/// 1. descend into the innermost generic argument (`core::option::Option<u32>` -> `u32`);
/// 2. drop any leading `&` reference markers;
/// 3. keep only the final `::`-separated path segment (`alloc::string::String` -> `String`).
fn simple_type_name<E>() -> &'static str {
    let full: &'static str = type_name::<E>();

    // Step 1: take the text between the last '<' and the first following '>' (if any).
    let innermost = if let Some(open) = full.rfind('<') {
        let after_open = &full[open + 1..];
        after_open.find('>').map_or(after_open, |close| &after_open[..close])
    } else {
        full
    };

    // Step 2: strip reference sigils.
    let unreferenced = innermost.trim_start_matches('&');

    // Step 3: the last path segment; `rsplit` always yields at least one piece.
    unreferenced.rsplit("::").next().unwrap_or(unreferenced)
}
46
/// Implement `Display` for a given type by formatting it as pretty-printed JSON (four-space indent).
///
/// The type must implement `serde::Serialize`. If serialization fails, or the serializer output is
/// not valid UTF-8, the problem is logged at `error` level and `std::fmt::Error` is returned.
///
/// Every path in the expansion is fully qualified (including the call to
/// `serde::Serialize::serialize`), so the macro works at any call site regardless of which traits
/// or prelude overrides (e.g. a local `Result` alias) the caller has in scope.
#[macro_export]
macro_rules! display_json {
    ($cls:ty) => {
        impl ::std::fmt::Display for $cls {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                let buf = ::std::vec::Vec::new();
                let serde_formatter = ::serde_json::ser::PrettyFormatter::with_indent(b"    ");
                let mut ser = ::serde_json::Serializer::with_formatter(buf, serde_formatter);

                // UFCS call: does not require `Serialize` to be imported where the macro is used.
                if let ::std::result::Result::Err(e) = ::serde::Serialize::serialize(self, &mut ser) {
                    ::log::error!("Failed to serialize: {}", e);
                    return ::std::result::Result::Err(::std::fmt::Error);
                }

                match ::std::str::from_utf8(&ser.into_inner()) {
                    ::std::result::Result::Ok(s) => f.write_str(s),
                    ::std::result::Result::Err(e) => {
                        ::log::error!("JSON serialization contained non-UTF-8 characters: {}", e);
                        ::std::result::Result::Err(::std::fmt::Error)
                    }
                }
            }
        }
    };
}
74
/// Implement `FromStr` for a given type by parsing the input string as JSON.
///
/// The type must implement `serde::Deserialize`. The associated error type is
/// `serde_json::Error`; parse failures are also logged at `debug` level together with the input.
///
/// Accepts any type path (`Foo`, `module::Foo`), not just a bare identifier. All paths in the
/// expansion are fully qualified so a caller-side `Result` alias or similar cannot break it.
#[macro_export]
macro_rules! from_str_json {
    ($cls:ty) => {
        impl ::std::str::FromStr for $cls {
            type Err = ::serde_json::Error;

            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                match ::serde_json::from_str::<Self>(s) {
                    ::std::result::Result::Ok(result) => ::std::result::Result::Ok(result),
                    ::std::result::Result::Err(e) => {
                        ::log::debug!("Failed to parse: {}: {:?}", s, e);
                        ::std::result::Result::Err(e)
                    }
                }
            }
        }
    };
}
94
/// How a list-like value is represented in JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonRep {
    /// The lone element is written bare, without an enclosing JSON array.
    Single,

    /// The elements are written as a JSON array (of any length).
    List,
}
101
/// Defines a generic list-like container named `$list_like_type` that remembers whether it came
/// from (and should be written back as) a bare JSON element or a JSON list, together with its
/// common trait impls: `Clone`, `Debug`, `PartialEq`/`Eq`, `From<E>`, `From<Vec<E>>`, `Index` and
/// `Deref<Target = [E]>`.
///
/// Trait methods are called with fully-qualified paths so the expansion does not depend on which
/// traits (`Deref`, `Index`, `Debug`, ...) happen to be imported where the macro is invoked.
macro_rules! define_list_like_type {
    ($list_like_type:ident) => {
        #[doc = concat!(
            "`",
            stringify!($list_like_type),
            "` allows a JSON field to be represented as the element itself (equivalent to a list of ",
            "1 item) or as a list of elements."
        )]
        pub struct $list_like_type<E> {
            elements: ::std::vec::Vec<E>,
            kind: $crate::serutil::JsonRep,
        }

        impl<E> $list_like_type<E> {
            /// Returns the JSON representation of the list.
            #[inline]
            pub fn kind(&self) -> $crate::serutil::JsonRep {
                self.kind
            }

            /// Returns the elements of the list as a slice.
            #[inline]
            pub fn as_slice(&self) -> &[E] {
                self.elements.as_slice()
            }

            /// Returns the elements of the list as a vector of references.
            ///
            /// Note: this inherent method takes precedence over `<[E]>::to_vec` (reachable via
            /// `Deref`), which would clone the elements instead.
            pub fn to_vec(&self) -> ::std::vec::Vec<&E> {
                self.elements.iter().collect()
            }

            /// Returns `true` if the list is empty.
            #[inline]
            pub fn is_empty(&self) -> bool {
                self.elements.is_empty()
            }

            /// Returns the number of elements in the list.
            #[inline]
            pub fn len(&self) -> usize {
                self.elements.len()
            }
        }

        impl<E> ::std::clone::Clone for $list_like_type<E>
        where
            E: ::std::clone::Clone,
        {
            fn clone(&self) -> Self {
                Self {
                    elements: self.elements.clone(),
                    kind: self.kind,
                }
            }
        }

        /// Debug-formats a `Single` list as its lone element and a `List` as a slice.
        impl<E> ::std::fmt::Debug for $list_like_type<E>
        where
            E: ::std::fmt::Debug,
        {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                match self.kind {
                    // `Single` lists are only ever constructed with exactly one element.
                    $crate::serutil::JsonRep::Single => ::std::fmt::Debug::fmt(&self.elements[0], f),
                    $crate::serutil::JsonRep::List => ::std::fmt::Debug::fmt(&self.elements, f),
                }
            }
        }

        /// Lists compare equal when their elements are equal. The JSON representation (`kind`) is
        /// deliberately ignored, so a single element equals a one-element list.
        impl<E> ::std::cmp::PartialEq for $list_like_type<E>
        where
            E: ::std::cmp::PartialEq,
        {
            fn eq(&self, other: &Self) -> bool {
                self.elements == other.elements
            }
        }

        impl<E> ::std::cmp::Eq for $list_like_type<E> where E: ::std::cmp::Eq {}

        /// Wraps a lone element; it will be serialized bare (`JsonRep::Single`).
        impl<E> ::std::convert::From<E> for $list_like_type<E> {
            fn from(v: E) -> Self {
                Self {
                    elements: ::std::vec![v],
                    kind: $crate::serutil::JsonRep::Single,
                }
            }
        }

        /// Wraps a vector; it will be serialized as a JSON list (`JsonRep::List`), even if it
        /// holds zero or one element.
        impl<E> ::std::convert::From<::std::vec::Vec<E>> for $list_like_type<E> {
            fn from(v: ::std::vec::Vec<E>) -> Self {
                Self {
                    elements: v,
                    kind: $crate::serutil::JsonRep::List,
                }
            }
        }

        impl<E, I> ::std::ops::Index<I> for $list_like_type<E>
        where
            I: ::std::slice::SliceIndex<[E]>,
        {
            type Output = <I as ::std::slice::SliceIndex<[E]>>::Output;

            fn index(&self, index: I) -> &Self::Output {
                ::std::ops::Index::index(self.elements.as_slice(), index)
            }
        }

        impl<E> ::std::ops::Deref for $list_like_type<E> {
            type Target = [E];

            fn deref(&self) -> &[E] {
                self.elements.as_slice()
            }
        }
    };
}
220
221define_list_like_type!(MapList);
222
/// Serde visitor that builds a `MapList` from either a single map or a sequence of maps.
///
/// It stores no data; the type parameter only records which element type to deserialize.
struct MapListVisitor<E> {
    phantom: ::std::marker::PhantomData<E>,
}
226
227impl<'de, E: Deserialize<'de>> Visitor<'de> for MapListVisitor<E> {
228    type Value = MapList<E>;
229
230    fn expecting(&self, f: &mut Formatter) -> FmtResult {
231        let tn = simple_type_name::<E>();
232        write!(f, "{tn} or list of {tn}")
233    }
234
235    fn visit_map<A: MapAccess<'de>>(self, access: A) -> Result<Self::Value, A::Error> {
236        let el = E::deserialize(MapAccessDeserializer::new(access))?;
237        Ok(MapList {
238            elements: vec![el],
239            kind: JsonRep::Single,
240        })
241    }
242
243    fn visit_seq<A: SeqAccess<'de>>(self, mut access: A) -> Result<Self::Value, A::Error> {
244        let mut result: Vec<E> = match access.size_hint() {
245            None => Vec::new(),
246            Some(size) => Vec::with_capacity(size),
247        };
248
249        while let Some(item) = access.next_element::<E>()? {
250            result.push(item);
251        }
252
253        Ok(MapList {
254            elements: result,
255            kind: JsonRep::List,
256        })
257    }
258}
259
260impl<E: Serialize> Display for MapList<E> {
261    fn fmt(&self, f: &mut Formatter) -> FmtResult {
262        let buf = Vec::new();
263        let serde_formatter = serde_json::ser::PrettyFormatter::with_indent(b"    ");
264        let mut ser = serde_json::Serializer::with_formatter(buf, serde_formatter);
265        match self.serialize(&mut ser) {
266            Ok(()) => (),
267            Err(e) => {
268                error!("Failed to serialize: {}", e);
269                return Err(FmtError {});
270            }
271        };
272        match from_utf8(&ser.into_inner()) {
273            Ok(s) => f.write_str(s),
274            Err(e) => {
275                error!("JSON serialization contained non-UTF-8 characters: {}", e);
276                Err(FmtError {})
277            }
278        }
279    }
280}
281
282impl<'de, E: Deserialize<'de>> Deserialize<'de> for MapList<E> {
283    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
284        deserializer.deserialize_any(MapListVisitor {
285            phantom: PhantomData,
286        })
287    }
288}
289
290impl<E: Serialize> Serialize for MapList<E> {
291    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
292        match self.kind {
293            JsonRep::Single => self.elements[0].serialize(serializer),
294            JsonRep::List => {
295                let mut seq = serializer.serialize_seq(Some(self.elements.len()))?;
296                for e in &self.elements {
297                    seq.serialize_element(e)?;
298                }
299                seq.end()
300            }
301        }
302    }
303}
304
305define_list_like_type!(StringLikeList);
306
/// Serde visitor that builds a `StringLikeList` from either a single string or a sequence of
/// strings, parsing each one with `T::from_str`.
///
/// It stores no data; the type parameter only records which element type to parse into.
struct StringListVisitor<T> {
    _phantom: ::std::marker::PhantomData<T>,
}
310
311impl<'de, T> Visitor<'de> for StringListVisitor<T>
312where
313    T: FromStr,
314    <T as FromStr>::Err: Display,
315{
316    type Value = StringLikeList<T>;
317
318    fn expecting(&self, f: &mut Formatter) -> FmtResult {
319        let tn = simple_type_name::<T>();
320        write!(f, "{tn} or list of {tn}")
321    }
322
323    fn visit_seq<A: SeqAccess<'de>>(self, access: A) -> Result<Self::Value, A::Error> {
324        let deserializer = SeqAccessDeserializer::new(access);
325        match Vec::<String>::deserialize(deserializer) {
326            Ok(l) => {
327                let mut result = Vec::with_capacity(l.len());
328                for e in &l {
329                    match T::from_str(e) {
330                        Ok(s) => result.push(s),
331                        Err(e) => return Err(de::Error::custom(e)),
332                    }
333                }
334                Ok(StringLikeList {
335                    elements: result,
336                    kind: JsonRep::List,
337                })
338            }
339            Err(e) => {
340                debug!("Failed to deserialize string list: {:?}", e);
341                Err(<A::Error as de::Error>::invalid_value(Unexpected::Seq, &self))
342            }
343        }
344    }
345
346    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
347        match T::from_str(v) {
348            Ok(s) => Ok(StringLikeList {
349                elements: vec![s],
350                kind: JsonRep::Single,
351            }),
352            Err(e) => Err(de::Error::custom(e)),
353        }
354    }
355}
356
357impl<E: ToString> Display for StringLikeList<E> {
358    fn fmt(&self, f: &mut Formatter) -> FmtResult {
359        let buf = Vec::new();
360        let serde_formatter = serde_json::ser::PrettyFormatter::with_indent(b"    ");
361        let mut ser = serde_json::Serializer::with_formatter(buf, serde_formatter);
362        match self.serialize(&mut ser) {
363            Ok(()) => (),
364            Err(e) => {
365                error!("Failed to serialize: {}", e);
366                return Err(FmtError {});
367            }
368        };
369        match from_utf8(&ser.into_inner()) {
370            Ok(s) => f.write_str(s),
371            Err(e) => {
372                error!("JSON serialization contained non-UTF-8 characters: {}", e);
373                Err(FmtError {})
374            }
375        }
376    }
377}
378
379impl<'de, T> Deserialize<'de> for StringLikeList<T>
380where
381    T: FromStr,
382    <T as FromStr>::Err: Display,
383{
384    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
385        deserializer.deserialize_any(StringListVisitor {
386            _phantom: PhantomData,
387        })
388    }
389}
390
391impl<T> Serialize for StringLikeList<T>
392where
393    T: ToString,
394{
395    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
396        match self.kind {
397            JsonRep::Single => {
398                let s = self.elements[0].to_string();
399                s.serialize(serializer)
400            }
401            JsonRep::List => {
402                let mut seq = Vec::with_capacity(self.elements.len());
403                for e in &self.elements {
404                    let s = e.to_string();
405                    seq.push(s);
406                }
407                seq.serialize(serializer)
408            }
409        }
410    }
411}
412
#[cfg(test)]
mod tests {
    use {
        super::{simple_type_name, JsonRep, MapList, StringLikeList},
        crate::display_json,
        indoc::indoc,
        serde::{ser::Serializer, Deserialize, Serialize},
        std::panic::catch_unwind,
    };

    /// Minimal map-shaped element used to exercise `MapList`.
    #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
    struct SimpleMap {
        pub value: u32,
    }

    // The clones below are deliberate: each list is built from independent copies of the fixtures.
    #[allow(clippy::redundant_clone)]
    #[test_log::test]
    fn test_basic_ops() {
        let map1 = SimpleMap {
            value: 42,
        };

        let map2 = SimpleMap {
            value: 43,
        };

        let el1a: MapList<SimpleMap> = map1.clone().into();
        let el1b: MapList<SimpleMap> = vec![map1.clone()].into();
        let el2a: MapList<SimpleMap> = vec![map1.clone(), map2.clone()].into();
        let el2b: MapList<SimpleMap> = vec![map1.clone(), map2].into();
        let el3: MapList<SimpleMap> = vec![].into();
        assert_eq!(el1a, el1b);
        assert_eq!(el1b, el1a);
        assert_ne!(el1a, el2a);
        assert_ne!(el2b, el1a);
        assert_eq!(el2a, el2b);

        assert!(!el1a.is_empty());
        assert!(!el1b.is_empty());
        assert!(!el2a.is_empty());
        assert!(el3.is_empty());

        assert_eq!(el1a.len(), 1);
        assert_eq!(el1b.len(), 1);
        assert_eq!(el2a.len(), 2);
        assert_eq!(el3.len(), 0);

        assert_eq!(el1a.clone(), el1a);

        assert_eq!(
            format!("{el1a}"),
            indoc! { r#"
            {
                "value": 42
            }"#}
        );
        assert_eq!(
            format!("{el1b}"),
            indoc! { r#"
            [
                {
                    "value": 42
                }
            ]"#}
        );
        assert_eq!(
            format!("{el2a}"),
            indoc! { r#"
            [
                {
                    "value": 42
                },
                {
                    "value": 43
                }
            ]"# }
        );

        assert_eq!(el1a[0].value, 42);
        assert_eq!(el1b[0].value, 42);
        let e = catch_unwind(|| {
            let new_el: MapList<SimpleMap> = map1.clone().into();
            println!("This won't print: {:?}", &new_el[1]);
        })
        .unwrap_err();
        assert_eq!(*e.downcast::<String>().unwrap(), "index out of bounds: the len is 1 but the index is 1");
    }

    /// Serializes to bytes that are not valid UTF-8, to drive the non-UTF-8 error path.
    #[derive(Clone, Debug)]
    struct SerBadUtf8 {}
    const BAD_UTF8: [u8; 3] = [0xc3, 0xc3, 0xc3];

    impl Serialize for SerBadUtf8 {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            // SAFETY: none — this deliberately violates `String`'s UTF-8 invariant (test-only) so
            // the Display impls see non-UTF-8 serializer output. The test relies on the string only
            // being handed to the serializer, whose output bytes are then checked with `from_utf8`.
            let s = unsafe { String::from_utf8_unchecked(BAD_UTF8.to_vec()) };
            serializer.serialize_str(&s)
        }
    }

    display_json!(SerBadUtf8);

    /// Always fails to serialize, to drive the "Failed to serialize" error path.
    #[derive(Clone, Debug)]
    struct SerFails;

    impl Serialize for SerFails {
        fn serialize<S: Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
            Err(<S::Error as serde::ser::Error>::custom("intentional serialization failure"))
        }
    }

    #[test_log::test]
    fn test_ser_fail() {
        // Non-UTF-8 output through MapList's Display impl.
        let el: MapList<SerBadUtf8> = vec![SerBadUtf8 {}].into();
        let e = catch_unwind(|| format!("{el}")).unwrap_err();
        let e2 = e.downcast::<String>().unwrap();
        assert!((*e2).contains("a formatting trait implementation returned an error"));

        // Non-UTF-8 output through the Display impl generated by `display_json!`.
        let e = catch_unwind(|| format!("{}", SerBadUtf8 {})).unwrap_err();
        let e2 = e.downcast::<String>().unwrap();
        assert!((*e2).contains("a formatting trait implementation returned an error"));

        // Serializer error through MapList's Display impl.
        let failing: MapList<SerFails> = SerFails.into();
        let e = catch_unwind(|| format!("{failing}")).unwrap_err();
        let e2 = e.downcast::<String>().unwrap();
        assert!((*e2).contains("a formatting trait implementation returned an error"));
    }

    #[test_log::test]
    fn test_map_list_deserialize() {
        let single: MapList<SimpleMap> = serde_json::from_str(r#"{"value": 1}"#).unwrap();
        assert_eq!(single.kind(), JsonRep::Single);
        assert_eq!(single.as_slice(), &[SimpleMap { value: 1 }]);

        let list: MapList<SimpleMap> = serde_json::from_str(r#"[{"value": 1}, {"value": 2}]"#).unwrap();
        assert_eq!(list.kind(), JsonRep::List);
        assert_eq!(list.len(), 2);
        assert_eq!(list[1].value, 2);

        let empty: MapList<SimpleMap> = serde_json::from_str("[]").unwrap();
        assert!(empty.is_empty());
        assert_eq!(empty.kind(), JsonRep::List);

        assert!(serde_json::from_str::<MapList<SimpleMap>>("42").is_err());
    }

    #[test_log::test]
    fn test_string_like_list() {
        let single: StringLikeList<u32> = serde_json::from_str(r#""42""#).unwrap();
        assert_eq!(single.kind(), JsonRep::Single);
        assert_eq!(single.as_slice(), &[42]);
        assert_eq!(serde_json::to_string(&single).unwrap(), r#""42""#);
        assert_eq!(format!("{single}"), r#""42""#);

        let list: StringLikeList<u32> = serde_json::from_str(r#"["1", "2"]"#).unwrap();
        assert_eq!(list.kind(), JsonRep::List);
        assert_eq!(list.as_slice(), &[1, 2]);
        assert_eq!(serde_json::to_string(&list).unwrap(), r#"["1","2"]"#);

        // Unparseable strings and non-string sequence members are rejected.
        assert!(serde_json::from_str::<StringLikeList<u32>>(r#""abc""#).is_err());
        assert!(serde_json::from_str::<StringLikeList<u32>>(r#"["1", "x"]"#).is_err());
        assert!(serde_json::from_str::<StringLikeList<u32>>(r#"["1", 2]"#).is_err());
    }

    #[test_log::test]
    fn test_simple_type_name() {
        assert_eq!(simple_type_name::<u32>(), "u32");
        assert_eq!(simple_type_name::<Option<u32>>(), "u32");
    }

    #[test_log::test]
    fn test_list_kind() {
        assert_eq!(JsonRep::Single, JsonRep::Single.clone());
        assert_eq!(JsonRep::List, JsonRep::List.clone());
        assert_ne!(JsonRep::Single, JsonRep::List);
        assert_eq!(format!("{:?}", JsonRep::Single), "Single");
        assert_eq!(format!("{:?}", JsonRep::List), "List");
    }
}
540}