shiny_configuration/value/
deserializer.rs

1use serde::de::value::MapAccessDeserializer;
2use serde::de::{DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor};
3use serde::Deserializer;
4use std::collections::hash_map::Iter;
5use std::collections::HashMap;
6use std::borrow::Cow;
7use std::error::Error;
8use std::fmt;
9use std::fmt::Display;
10use super::{Number, Value};
11
12pub struct ValueDeserializer<'a>(pub &'a Value);
13
14impl<'de> Deserializer<'de> for Number {
15    type Error = DeserializationError;
16
17    fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
18        where
19            V: Visitor<'de>,
20    {
21        match self {
22            Number::Integer(value) => v.visit_i64(value),
23            Number::Float(value) => v.visit_f64(value),
24            Number::UInteger(value) => v.visit_u64(value),
25        }
26    }
27
28    serde::forward_to_deserialize_any! {
29        bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq enum
30        bytes byte_buf map struct unit newtype_struct
31        ignored_any unit_struct tuple_struct tuple option identifier
32    }
33}
34
35impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
36    type Error = DeserializationError;
37
38    fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
39        where
40            V: Visitor<'de>,
41    {
42        match self.0 {
43            Value::String(ref s) => v.visit_str(s),
44            Value::Bool(b) => v.visit_bool(*b),
45            Value::Number(n) => n.deserialize_any(v),
46            Value::None => v.visit_none(),
47            Value::Map(map) => v.visit_map(MapDeserializer::new(map)),
48            Value::Array(seq) => v.visit_seq(SequenceDeserializer::new(seq)),
49        }
50    }
51
52    serde::forward_to_deserialize_any! {
53        bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str
54        string seq bytes byte_buf map struct
55        ignored_any tuple_struct tuple identifier
56    }
57
58    fn deserialize_option<V>(self, v: V) -> Result<V::Value, Self::Error>
59        where
60            V: Visitor<'de>,
61    {
62        match self.0 {
63            Value::None => v.visit_none(),
64            _ => v.visit_some(self),
65        }
66    }
67
68    fn deserialize_unit<V>(self, v: V) -> Result<V::Value, Self::Error>
69        where
70            V: Visitor<'de>,
71    {
72        match self.0 {
73            Value::None => v.visit_unit(),
74            _ => self.deserialize_any(v),
75        }
76    }
77
78    fn deserialize_unit_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
79        where
80            V: Visitor<'de>,
81    {
82        match self.0 {
83            Value::None => v.visit_unit(),
84            _ => self.deserialize_any(v),
85        }
86    }
87
88    fn deserialize_newtype_struct<V>(
89        self,
90        _name: &'static str,
91        v: V,
92    ) -> Result<V::Value, Self::Error>
93        where
94            V: Visitor<'de>,
95    {
96        v.visit_newtype_struct(self)
97    }
98
99    fn deserialize_enum<V>(
100        self,
101        _name: &'static str,
102        _variants: &'static [&'static str],
103        v: V,
104    ) -> Result<V::Value, Self::Error>
105        where
106            V: Visitor<'de>,
107    {
108        match self.0 {
109            // Unit variant
110            Value::String(s) => v.visit_enum((**s).into_deserializer()),
111            // Newtype variant, tuple variant, or struct variant
112            Value::Map(ref map) => {
113                let map_access = MapDeserializer::new(map);
114                v.visit_enum(MapAccessDeserializer::new(map_access))
115            }
116            // Everything else
117            _ => self.deserialize_any(v),
118        }
119    }
120
121    fn is_human_readable(&self) -> bool {
122        false
123    }
124}
125
126struct MapDeserializer<'de> {
127    iter: Iter<'de, String, Value>,
128    last_kv_pair: Option<(&'de String, &'de Value)>,
129}
130
131impl<'de> MapDeserializer<'de> {
132    fn new(map: &'de HashMap<String, Value>) -> Self {
133        MapDeserializer {
134            iter: map.iter(),
135            last_kv_pair: None,
136        }
137    }
138}
139
140impl<'de> MapAccess<'de> for MapDeserializer<'de> {
141    type Error = DeserializationError;
142
143    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, DeserializationError>
144        where
145            K: DeserializeSeed<'de>,
146    {
147        if let Some((k, v)) = self.iter.next() {
148            let result = seed.deserialize(k.as_str().into_deserializer()).map(Some);
149
150            self.last_kv_pair = Some((k, v));
151            result.map_err(|err: DeserializationError| err.with_prefix(k))
152        } else {
153            Ok(None)
154        }
155    }
156
157    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, DeserializationError>
158        where
159            V: DeserializeSeed<'de>,
160    {
161        let (key, value) = self
162            .last_kv_pair
163            .take()
164            .expect("visit_value called before visit_key");
165
166        seed.deserialize(ValueDeserializer(value))
167            .map_err(|err: DeserializationError| err.with_prefix(key))
168    }
169}
170
171struct SequenceDeserializer<'de> {
172    iter: std::iter::Enumerate<std::slice::Iter<'de, Value>>,
173    len: usize,
174}
175
176impl<'de> SequenceDeserializer<'de> {
177    fn new(vec: &'de [Value]) -> Self {
178        SequenceDeserializer {
179            iter: vec.iter().enumerate(),
180            len: vec.len(),
181        }
182    }
183}
184
185impl<'de> SeqAccess<'de> for SequenceDeserializer<'de> {
186    type Error = DeserializationError;
187
188    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
189        where
190            T: DeserializeSeed<'de>,
191    {
192        if let Some((i, item)) = self.iter.next() {
193            self.len -= 1;
194            seed.deserialize(ValueDeserializer(item))
195                .map(Some)
196                .map_err(|e: DeserializationError| e.with_prefix(&i.to_string()))
197        } else {
198            Ok(None)
199        }
200    }
201
202    fn size_hint(&self) -> Option<usize> {
203        Some(self.len)
204    }
205}
206
207#[derive(Debug)]
208pub struct DeserializationError {
209    kind: ErrorKind,
210    key: Vec<String>,
211}
212
213impl DeserializationError {
214    pub fn with_prefix(mut self, prefix: &str) -> Self {
215        self.key.push(prefix.to_string());
216        self
217    }
218}
219
220impl Display for DeserializationError {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        let key = self
223            .key
224            .iter()
225            .rev()
226            .map(|x| x.as_ref())
227            .collect::<Vec<&str>>()
228            .join(".");
229        write!(f, "{}, key = `{}`", self.kind, key)
230    }
231}
232
233impl Error for DeserializationError {}
234
235#[derive(Debug)]
236enum ErrorKind {
237    Message(String),
238
239    /// An invalid type: (actual, expected). See [`serde::de::Error::invalid_type()`].
240    InvalidType(UnexpectedOwned, String),
241
242    /// An invalid value: (actual, expected). See [`serde::de::Error::invalid_value()`].
243    InvalidValue(UnexpectedOwned, String),
244
245    /// Too many or too few items: (actual, expected). See [`serde::de::Error::invalid_length()`].
246    InvalidLength(usize, String),
247
248    /// A variant with an unrecognized name: (actual, expected). See [`serde::de::Error::unknown_variant()`].
249    UnknownVariant(String, &'static [&'static str]),
250
251    /// A field with an unrecognized name: (actual, expected). See [`serde::de::Error::unknown_field()`].
252    UnknownField(String, &'static [&'static str]),
253
254    /// A field was missing: (name). See [`serde::de::Error::missing_field()`].
255    MissingField(Cow<'static, str>),
256
257    /// A field appeared more than once: (name). See [`serde::de::Error::duplicate_field()`].
258    DuplicateField(&'static str),
259}
260
261impl Display for ErrorKind {
262    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
263        match self {
264            ErrorKind::Message(msg) => f.write_str(msg),
265            ErrorKind::InvalidType(v, exp) => {
266                write!(f, "Invalid type `{}`, expected `{}`", v, exp)
267            }
268            ErrorKind::InvalidValue(v, exp) => {
269                write!(f, "Invalid value `{}`, expected `{}`", v, exp)
270            }
271            ErrorKind::InvalidLength(v, exp) => {
272                write!(f, "Invalid length `{}`, expected `{}`", v, exp)
273            }
274            ErrorKind::UnknownVariant(v, exp) => {
275                write!(
276                    f,
277                    "Unknown variant `{}`, expected `{}`",
278                    v,
279                    OneOfDisplayWrapper(exp)
280                )
281            }
282            ErrorKind::UnknownField(v, exp) => {
283                write!(
284                    f,
285                    "Unknown field `{}`, expected `{}`",
286                    v,
287                    OneOfDisplayWrapper(exp)
288                )
289            }
290            ErrorKind::MissingField(v) => {
291                write!(f, "Missing field `{}`", v)
292            }
293            ErrorKind::DuplicateField(v) => {
294                write!(f, "Duplicate field `{}`", v)
295            }
296        }
297    }
298}
299
300impl serde::de::Error for DeserializationError {
301    fn custom<T>(msg: T) -> Self
302        where
303            T: Display,
304    {
305        DeserializationError {
306            kind: ErrorKind::Message(msg.to_string()),
307            key: Vec::new(),
308        }
309    }
310
311    fn invalid_type(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
312        DeserializationError {
313            kind: ErrorKind::InvalidType(unexp.into(), exp.to_string()),
314            key: Vec::new(),
315        }
316    }
317
318    fn invalid_value(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
319        DeserializationError {
320            kind: ErrorKind::InvalidValue(unexp.into(), exp.to_string()),
321            key: Vec::new(),
322        }
323    }
324
325    fn invalid_length(len: usize, exp: &dyn serde::de::Expected) -> Self {
326        DeserializationError {
327            kind: ErrorKind::InvalidLength(len, exp.to_string()),
328            key: Vec::new(),
329        }
330    }
331
332    fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
333        DeserializationError {
334            kind: ErrorKind::UnknownVariant(variant.into(), expected),
335            key: Vec::new(),
336        }
337    }
338
339    fn unknown_field(field: &str, expected: &'static [&'static str]) -> Self {
340        DeserializationError {
341            kind: ErrorKind::UnknownField(field.into(), expected),
342            key: Vec::new(),
343        }
344    }
345
346    fn missing_field(field: &'static str) -> Self {
347        DeserializationError {
348            kind: ErrorKind::MissingField(field.into()),
349            key: Vec::new(),
350        }
351    }
352
353    fn duplicate_field(field: &'static str) -> Self {
354        DeserializationError {
355            kind: ErrorKind::DuplicateField(field),
356            key: Vec::new(),
357        }
358    }
359}
360
/// Owned counterpart of `serde::de::Unexpected`, so the offending input can
/// be stored inside an error without borrowing from it.
#[derive(Clone, Debug, PartialEq)]
pub enum UnexpectedOwned {
    Bool(bool),
    Unsigned(u128),
    Signed(i128),
    Float(f64),
    Char(char),
    Str(String),
    Bytes(Vec<u8>),
    Unit,
    Option,
    NewtypeStruct,
    Seq,
    Map,
    Enum,
    UnitVariant,
    NewtypeVariant,
    TupleVariant,
    StructVariant,
    Other(String),
}

/// Describes the unexpected input: a type label, plus the value for the
/// variants that carry one.
impl fmt::Display for UnexpectedOwned {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants with a payload format themselves and return early; the
        // payload-free ones only differ by a fixed label.
        let label = match self {
            Self::Bool(flag) => return write!(f, "bool {}", flag),
            Self::Unsigned(number) => return write!(f, "unsigned int `{}`", number),
            Self::Signed(number) => return write!(f, "signed int `{}`", number),
            Self::Float(number) => return write!(f, "float `{}`", number),
            Self::Char(character) => return write!(f, "char {:?}", character),
            Self::Str(text) => return write!(f, "string {:?}", text),
            Self::Bytes(raw) => return write!(f, "bytes {:?}", raw),
            // Delegate so that formatter flags (width, fill, ...) still apply.
            Self::Other(description) => return fmt::Display::fmt(description, f),
            Self::Unit => "unit",
            Self::Option => "option",
            Self::NewtypeStruct => "new-type struct",
            Self::Seq => "sequence",
            Self::Map => "map",
            Self::Enum => "enum",
            Self::UnitVariant => "unit variant",
            Self::NewtypeVariant => "new-type variant",
            Self::TupleVariant => "tuple variant",
            Self::StructVariant => "struct variant",
        };

        f.write_str(label)
    }
}
408
409impl From<serde::de::Unexpected<'_>> for UnexpectedOwned {
410    fn from(value: serde::de::Unexpected<'_>) -> UnexpectedOwned {
411        match value {
412            serde::de::Unexpected::Bool(v) => UnexpectedOwned::Bool(v),
413            serde::de::Unexpected::Unsigned(v) => UnexpectedOwned::Unsigned(v as u128),
414            serde::de::Unexpected::Signed(v) => UnexpectedOwned::Signed(v as i128),
415            serde::de::Unexpected::Float(v) => UnexpectedOwned::Float(v),
416            serde::de::Unexpected::Char(v) => UnexpectedOwned::Char(v),
417            serde::de::Unexpected::Str(v) => UnexpectedOwned::Str(v.into()),
418            serde::de::Unexpected::Bytes(v) => UnexpectedOwned::Bytes(v.into()),
419            serde::de::Unexpected::Unit => UnexpectedOwned::Unit,
420            serde::de::Unexpected::Option => UnexpectedOwned::Option,
421            serde::de::Unexpected::NewtypeStruct => UnexpectedOwned::NewtypeStruct,
422            serde::de::Unexpected::Seq => UnexpectedOwned::Seq,
423            serde::de::Unexpected::Map => UnexpectedOwned::Map,
424            serde::de::Unexpected::Enum => UnexpectedOwned::Enum,
425            serde::de::Unexpected::UnitVariant => UnexpectedOwned::UnitVariant,
426            serde::de::Unexpected::NewtypeVariant => UnexpectedOwned::NewtypeVariant,
427            serde::de::Unexpected::TupleVariant => UnexpectedOwned::TupleVariant,
428            serde::de::Unexpected::StructVariant => UnexpectedOwned::StructVariant,
429            serde::de::Unexpected::Other(v) => UnexpectedOwned::Other(v.into()),
430        }
431    }
432}
433
/// Formats a static list of alternatives for "expected ..." messages.
struct OneOfDisplayWrapper(pub &'static [&'static str]);

/// Renders `none`, a single name, `a or b`, or `one of a, b, c, ...`, with
/// every name wrapped in backticks.
impl Display for OneOfDisplayWrapper {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0 {
            [] => f.write_str("none"),
            [only] => write!(f, "`{}`", only),
            [first, second] => write!(f, "`{}` or `{}`", first, second),
            [first, rest @ ..] => {
                write!(f, "one of `{}`", first)?;
                for alternative in rest {
                    write!(f, ", `{}`", alternative)?;
                }
                Ok(())
            }
        }
    }
}
456
457
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde::Deserialize;

    use super::ValueDeserializer;
    use crate::value::{Number, Value};

    /// Wraps `text` in an owned `Value::String`.
    fn string(text: &str) -> Value {
        Value::String(text.to_string())
    }

    /// Builds a `Value::Map` from `(key, value)` pairs.
    fn map(entries: Vec<(&str, Value)>) -> Value {
        Value::Map(
            entries
                .into_iter()
                .map(|(key, value)| (key.to_string(), value))
                .collect::<HashMap<String, Value>>(),
        )
    }

    #[test]
    fn test_deserialize_none() {
        let none = Value::None;

        let unit = <()>::deserialize(ValueDeserializer(&none)).unwrap();
        assert_eq!(unit, ());

        let absent = <Option<i32>>::deserialize(ValueDeserializer(&none)).unwrap();
        assert_eq!(absent, None);
    }

    #[test]
    fn test_deserialize_int() {
        let value = Value::Number(42_i64.into());

        let signed = i64::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(signed, 42);

        let unsigned = u64::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(unsigned, 42);
    }

    #[test]
    fn test_deserialize_uint() {
        let value = Value::Number(42_u64.into());

        let unsigned = u64::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(unsigned, 42);

        let signed = i64::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(signed, 42);
    }

    #[test]
    fn test_deserialize_float() {
        let value = Value::Number(42.1_f64.into());

        let float = f64::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(float, 42.1);
    }

    #[test]
    fn test_deserialize_string() {
        let value = string("hello world");

        let text = String::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(text, "hello world");
    }

    #[test]
    fn test_deserialize_bool() {
        let yes = bool::deserialize(ValueDeserializer(&Value::Bool(true))).unwrap();
        assert!(yes);

        let no = bool::deserialize(ValueDeserializer(&Value::Bool(false))).unwrap();
        assert!(!no);
    }

    #[test]
    fn test_deserialize_map() {
        let value = map(vec![
            ("hello", string("world")),
            ("world", string("hello")),
        ]);

        let parsed = HashMap::<String, String>::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(parsed.get("hello").unwrap(), "world");
        assert_eq!(parsed.get("world").unwrap(), "hello");
    }

    #[test]
    fn test_deserialize_array() {
        let value = Value::Array(vec![string("hello"), string("world")]);

        let parsed = Vec::<String>::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(parsed[0], "hello");
        assert_eq!(parsed[1], "world");
    }

    #[test]
    fn test_deserialize_struct() {
        #[derive(Deserialize)]
        struct TestStruct {
            pub string: String,
            pub int: i64,
            pub optional: Option<i32>,
            pub optional_missing: Option<i32>,
            pub optional_present: Option<i32>,
            pub unit: (),
        }

        // `optional_missing` is deliberately left out of the input.
        let value = map(vec![
            ("string", string("Hello World")),
            ("int", Value::Number(Number::UInteger(42))),
            ("optional", Value::None),
            ("optional_present", Value::Number(42.into())),
            ("unit", Value::None),
        ]);

        let parsed = TestStruct::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(parsed.string, "Hello World");
        assert_eq!(parsed.int, 42);
        assert_eq!(parsed.optional, None);
        assert_eq!(parsed.optional_missing, None);
        assert_eq!(parsed.optional_present, Some(42));
        assert_eq!(parsed.unit, ());
    }

    #[test]
    fn test_deserialize_unit_struct() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        struct TestStruct;

        let parsed = TestStruct::deserialize(ValueDeserializer(&Value::None)).unwrap();

        assert_eq!(parsed, TestStruct);
    }

    #[test]
    fn test_deserialize_newtype_struct() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        struct TestStruct(String);

        let value = string("Hello World");

        let parsed = TestStruct::deserialize(ValueDeserializer(&value)).unwrap();

        assert_eq!(parsed.0, "Hello World");
    }

    #[test]
    fn test_deserialize_enum() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        enum TestEnum {
            Unit,
            NewType(String),
            Complex { value: String, id: i32 },
        }

        // Unit variant given as a single-entry map.
        let value = map(vec![("Unit", Value::None)]);
        let parsed = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(parsed, TestEnum::Unit);

        // Unit variant given as a bare string.
        let value = string("Unit");
        let parsed = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(parsed, TestEnum::Unit);

        // Newtype variant.
        let value = map(vec![("NewType", string("Hello World"))]);
        let parsed = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(parsed, TestEnum::NewType("Hello World".to_string()));

        // Struct variant.
        let value = map(vec![(
            "Complex",
            map(vec![
                ("value", string("Hello World")),
                ("id", Value::Number(Number::UInteger(42))),
            ]),
        )]);
        let parsed = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(
            parsed,
            TestEnum::Complex {
                value: "Hello World".to_string(),
                id: 42,
            }
        );
    }

    #[test]
    fn test_deserialize_error_invalid_type() {
        #[derive(Deserialize, Debug)]
        struct TestStruct {
            #[serde(rename = "string")]
            pub _string: i32,
        }

        let value = map(vec![("string", string("Hello World"))]);

        let error = TestStruct::deserialize(ValueDeserializer(&value)).unwrap_err();

        assert_eq!(
            error.to_string(),
            "Invalid type `string \"Hello World\"`, expected `i32`, key = `string`"
        );
    }
}
701        );
702    }
703}