Skip to main content

rust_yaml/serde_integration/
de.rs

//! A [`serde::Deserializer`] implementation that drives serde from a [`Value`].
//!
//! `from_str` / `from_slice` / `from_reader` parse YAML into a `Value` via the
//! existing `load_str` pipeline, then walk the resulting tree with this deserializer.
5
6use crate::{Error, Value, Yaml};
7use serde::de::{
8    self, DeserializeOwned, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor,
9};
10use std::io::Read;
11
12/// Borrowed deserializer over an existing `Value`.
13pub struct ValueDeserializer<'a> {
14    value: &'a Value,
15}
16
17impl<'a> ValueDeserializer<'a> {
18    /// Wrap a `Value` reference for deserialization.
19    #[must_use]
20    pub fn new(value: &'a Value) -> Self {
21        Self { value }
22    }
23}
24
25impl<'de, 'a> Deserializer<'de> for ValueDeserializer<'a> {
26    type Error = Error;
27
28    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
29        match self.value {
30            Value::Null => visitor.visit_unit(),
31            Value::Bool(b) => visitor.visit_bool(*b),
32            Value::Int(i) => visitor.visit_i64(*i),
33            Value::Float(f) => visitor.visit_f64(*f),
34            Value::String(s) => visitor.visit_str(s),
35            Value::Sequence(seq) => visitor.visit_seq(SeqAccessImpl { iter: seq.iter() }),
36            Value::Mapping(map) => visitor.visit_map(MapAccessImpl {
37                iter: map.iter(),
38                next_value: None,
39            }),
40        }
41    }
42
43    serde::forward_to_deserialize_any! {
44        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
45        bytes byte_buf unit unit_struct newtype_struct seq tuple tuple_struct
46        map struct identifier ignored_any
47    }
48
49    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
50        match self.value {
51            Value::Null => visitor.visit_none(),
52            _ => visitor.visit_some(self),
53        }
54    }
55
56    fn deserialize_enum<V: Visitor<'de>>(
57        self,
58        _name: &'static str,
59        _variants: &'static [&'static str],
60        visitor: V,
61    ) -> Result<V::Value, Error> {
62        match self.value {
63            // Unit variant: a bare string naming the variant.
64            Value::String(s) => {
65                let de = de::value::StrDeserializer::<Error>::new(s.as_str());
66                visitor.visit_enum(de)
67            }
68
69            // Tuple / struct / newtype variant: a single-entry mapping whose key
70            // is the variant name and whose value carries the payload.
71            Value::Mapping(map) if map.len() == 1 => {
72                let (k, v) = if let Some(entry) = map.iter().next() {
73                    entry
74                } else {
75                    return Err(<Error as de::Error>::custom(
76                        "internal: len==1 but no entry",
77                    ));
78                };
79                let name = match k {
80                    Value::String(s) => s.as_str(),
81                    _ => {
82                        return Err(<Error as de::Error>::custom(
83                            "enum variant key must be a string",
84                        ));
85                    }
86                };
87                visitor.visit_enum(EnumAccessImpl {
88                    variant: name,
89                    value: v,
90                })
91            }
92
93            other => Err(<Error as de::Error>::custom(format!(
94                "expected enum (string or single-entry mapping), got {other:?}"
95            ))),
96        }
97    }
98}
99
100struct SeqAccessImpl<'a> {
101    iter: std::slice::Iter<'a, Value>,
102}
103
104impl<'de, 'a> SeqAccess<'de> for SeqAccessImpl<'a> {
105    type Error = Error;
106
107    fn next_element_seed<T: DeserializeSeed<'de>>(
108        &mut self,
109        seed: T,
110    ) -> Result<Option<T::Value>, Error> {
111        match self.iter.next() {
112            Some(v) => seed.deserialize(ValueDeserializer::new(v)).map(Some),
113            None => Ok(None),
114        }
115    }
116
117    fn size_hint(&self) -> Option<usize> {
118        Some(self.iter.len())
119    }
120}
121
122struct MapAccessImpl<'a> {
123    iter: indexmap::map::Iter<'a, Value, Value>,
124    next_value: Option<&'a Value>,
125}
126
127impl<'de, 'a> MapAccess<'de> for MapAccessImpl<'a> {
128    type Error = Error;
129
130    fn next_key_seed<K: DeserializeSeed<'de>>(
131        &mut self,
132        seed: K,
133    ) -> Result<Option<K::Value>, Error> {
134        match self.iter.next() {
135            Some((k, v)) => {
136                self.next_value = Some(v);
137                seed.deserialize(ValueDeserializer::new(k)).map(Some)
138            }
139            None => Ok(None),
140        }
141    }
142
143    fn next_value_seed<V: DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value, Error> {
144        let v = self
145            .next_value
146            .take()
147            .ok_or_else(|| <Error as de::Error>::custom("next_value before next_key"))?;
148        seed.deserialize(ValueDeserializer::new(v))
149    }
150
151    fn size_hint(&self) -> Option<usize> {
152        Some(self.iter.len())
153    }
154}
155
156struct EnumAccessImpl<'a> {
157    variant: &'a str,
158    value: &'a Value,
159}
160
161impl<'de, 'a> serde::de::EnumAccess<'de> for EnumAccessImpl<'a> {
162    type Error = Error;
163    type Variant = VariantAccessImpl<'a>;
164
165    fn variant_seed<V: DeserializeSeed<'de>>(
166        self,
167        seed: V,
168    ) -> Result<(V::Value, Self::Variant), Error> {
169        let de = de::value::StrDeserializer::<Error>::new(self.variant);
170        let name: V::Value = seed.deserialize(de)?;
171        Ok((name, VariantAccessImpl { value: self.value }))
172    }
173}
174
175struct VariantAccessImpl<'a> {
176    value: &'a Value,
177}
178
179impl<'de, 'a> serde::de::VariantAccess<'de> for VariantAccessImpl<'a> {
180    type Error = Error;
181
182    fn unit_variant(self) -> Result<(), Error> {
183        match self.value {
184            Value::Null => Ok(()),
185            _ => Err(<Error as de::Error>::custom(
186                "unit variant must have Null payload",
187            )),
188        }
189    }
190
191    fn newtype_variant_seed<T: DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value, Error> {
192        seed.deserialize(ValueDeserializer::new(self.value))
193    }
194
195    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value, Error> {
196        ValueDeserializer::new(self.value).deserialize_seq(visitor)
197    }
198
199    fn struct_variant<V: Visitor<'de>>(
200        self,
201        _fields: &'static [&'static str],
202        visitor: V,
203    ) -> Result<V::Value, Error> {
204        ValueDeserializer::new(self.value).deserialize_map(visitor)
205    }
206}
207
208/// Parse YAML from a string into `T`.
209///
210/// # Errors
211///
212/// Returns an error if the YAML fails to parse or if `T`'s `Deserialize`
213/// impl rejects the resulting structure.
214pub fn from_str<T: DeserializeOwned>(s: &str) -> Result<T, Error> {
215    let value = Yaml::new().load_str(s)?;
216    T::deserialize(ValueDeserializer::new(&value))
217}
218
219/// Parse YAML from a byte slice into `T`.
220///
221/// # Errors
222///
223/// Returns an error if the bytes are not valid UTF-8, if the YAML fails to
224/// parse, or if `T`'s `Deserialize` impl rejects the resulting structure.
225pub fn from_slice<T: DeserializeOwned>(b: &[u8]) -> Result<T, Error> {
226    let s = std::str::from_utf8(b).map_err(Error::from)?;
227    from_str(s)
228}
229
230/// Parse YAML from a reader into `T`.
231///
232/// # Errors
233///
234/// Returns an error if reading fails, if the YAML fails to parse, or if
235/// `T`'s `Deserialize` impl rejects the resulting structure.
236pub fn from_reader<R: Read, T: DeserializeOwned>(mut r: R) -> Result<T, Error> {
237    let mut buf = String::new();
238    r.read_to_string(&mut buf).map_err(Error::from)?;
239    from_str(&buf)
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalars and `Option` deserialize straight from single-value documents.
    #[test]
    // Exact float comparison is intended: 1.5 is exactly representable.
    #[allow(clippy::float_cmp)]
    fn from_str_parses_primitives() {
        let flag: bool = from_str("true").unwrap();
        assert!(flag);

        let int: i64 = from_str("42").unwrap();
        assert_eq!(int, 42i64);

        let float: f64 = from_str("1.5").unwrap();
        assert_eq!(float, 1.5f64);

        let text: String = from_str("hello").unwrap();
        assert_eq!(text, "hello".to_string());

        let absent: Option<i32> = from_str("null").unwrap();
        assert_eq!(absent, None);

        let present: Option<i32> = from_str("7").unwrap();
        assert_eq!(present, Some(7));
    }

    /// A flat block sequence becomes a `Vec`.
    #[test]
    fn vec_of_int_round_trips() {
        let parsed = from_str::<Vec<i32>>("- 1\n- 2\n- 3\n").unwrap();
        assert_eq!(parsed, vec![1, 2, 3]);
    }

    /// Nested block sequences become nested `Vec`s.
    #[test]
    fn nested_seq_round_trips() {
        let parsed = from_str::<Vec<Vec<i32>>>("- - 1\n  - 2\n- - 3\n  - 4\n").unwrap();
        let expected = vec![vec![1, 2], vec![3, 4]];
        assert_eq!(parsed, expected);
    }

    /// A mapping with matching keys fills a derived struct.
    #[test]
    fn struct_round_trips_through_from_str() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Cfg {
            name: String,
            version: u32,
            enabled: bool,
        }

        let expected = Cfg {
            name: "rust".into(),
            version: 11,
            enabled: true,
        };
        let parsed = from_str::<Cfg>("name: rust\nversion: 11\nenabled: true\n").unwrap();
        assert_eq!(parsed, expected);
    }

    /// A mapping with string keys fills a `HashMap`.
    #[test]
    fn hashmap_round_trips_through_from_str() {
        use std::collections::HashMap;

        let parsed = from_str::<HashMap<String, i32>>("a: 1\nb: 2\n").unwrap();
        assert_eq!(parsed.get("a"), Some(&1));
        assert_eq!(parsed.get("b"), Some(&2));
    }

    /// All three entry points agree on the same input.
    #[test]
    fn from_slice_and_from_reader_match_from_str() {
        type Doc = indexmap::IndexMap<String, serde_yaml::Value>;

        let input = "name: rust\nversion: 11\n";

        let via_str: Doc = from_str(input).unwrap();
        let via_slice: Doc = from_slice(input.as_bytes()).unwrap();
        let via_reader: Doc = from_reader(std::io::Cursor::new(input)).unwrap();

        assert_eq!(via_str, via_slice);
        assert_eq!(via_str, via_reader);
    }

    /// A bare string selects a unit variant.
    #[test]
    fn unit_variant_deserializes_from_string() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Color {
            Red,
            Green,
            Blue,
        }

        let parsed = from_str::<Color>("Red").unwrap();
        assert_eq!(parsed, Color::Red);
    }

    /// A single-entry mapping selects a newtype or tuple variant.
    #[test]
    // Exact float comparison is intended: the literals are exactly representable.
    #[allow(clippy::float_cmp)]
    fn tuple_variant_deserializes_from_tagged_map() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Shape {
            Circle(f64),
            Rect(f64, f64),
        }

        let circle = from_str::<Shape>("Circle: 1.5\n").unwrap();
        assert_eq!(circle, Shape::Circle(1.5));

        let rect = from_str::<Shape>("Rect:\n  - 2.0\n  - 3.0\n").unwrap();
        assert_eq!(rect, Shape::Rect(2.0, 3.0));
    }

    /// A single-entry mapping with a nested mapping selects a struct variant.
    #[test]
    fn struct_variant_deserializes_from_tagged_map() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Msg {
            Point { x: i32, y: i32 },
        }

        let parsed = from_str::<Msg>("Point:\n  x: 1\n  y: 2\n").unwrap();
        assert_eq!(parsed, Msg::Point { x: 1, y: 2 });
    }
}