//! Serde deserialization support for `rust_yaml` (`serde_integration/de.rs`).
//!
//! Input text is first loaded into a [`Value`] tree, which [`ValueDeserializer`]
//! then walks to drive a serde [`Visitor`].

use std::io::Read;

use serde::de::{
    self, DeserializeOwned, DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor,
};

use crate::{Error, Value, Yaml};

12pub struct ValueDeserializer<'a> {
14 value: &'a Value,
15}
16
17impl<'a> ValueDeserializer<'a> {
18 #[must_use]
20 pub fn new(value: &'a Value) -> Self {
21 Self { value }
22 }
23}
24
25impl<'de, 'a> Deserializer<'de> for ValueDeserializer<'a> {
26 type Error = Error;
27
28 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
29 match self.value {
30 Value::Null => visitor.visit_unit(),
31 Value::Bool(b) => visitor.visit_bool(*b),
32 Value::Int(i) => visitor.visit_i64(*i),
33 Value::Float(f) => visitor.visit_f64(*f),
34 Value::String(s) => visitor.visit_str(s),
35 Value::Sequence(seq) => visitor.visit_seq(SeqAccessImpl { iter: seq.iter() }),
36 Value::Mapping(map) => visitor.visit_map(MapAccessImpl {
37 iter: map.iter(),
38 next_value: None,
39 }),
40 }
41 }
42
43 serde::forward_to_deserialize_any! {
44 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
45 bytes byte_buf unit unit_struct newtype_struct seq tuple tuple_struct
46 map struct identifier ignored_any
47 }
48
49 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
50 match self.value {
51 Value::Null => visitor.visit_none(),
52 _ => visitor.visit_some(self),
53 }
54 }
55
56 fn deserialize_enum<V: Visitor<'de>>(
57 self,
58 _name: &'static str,
59 _variants: &'static [&'static str],
60 visitor: V,
61 ) -> Result<V::Value, Error> {
62 match self.value {
63 Value::String(s) => {
65 let de = de::value::StrDeserializer::<Error>::new(s.as_str());
66 visitor.visit_enum(de)
67 }
68
69 Value::Mapping(map) if map.len() == 1 => {
72 let (k, v) = if let Some(entry) = map.iter().next() {
73 entry
74 } else {
75 return Err(<Error as de::Error>::custom(
76 "internal: len==1 but no entry",
77 ));
78 };
79 let name = match k {
80 Value::String(s) => s.as_str(),
81 _ => {
82 return Err(<Error as de::Error>::custom(
83 "enum variant key must be a string",
84 ));
85 }
86 };
87 visitor.visit_enum(EnumAccessImpl {
88 variant: name,
89 value: v,
90 })
91 }
92
93 other => Err(<Error as de::Error>::custom(format!(
94 "expected enum (string or single-entry mapping), got {other:?}"
95 ))),
96 }
97 }
98}
99
100struct SeqAccessImpl<'a> {
101 iter: std::slice::Iter<'a, Value>,
102}
103
104impl<'de, 'a> SeqAccess<'de> for SeqAccessImpl<'a> {
105 type Error = Error;
106
107 fn next_element_seed<T: DeserializeSeed<'de>>(
108 &mut self,
109 seed: T,
110 ) -> Result<Option<T::Value>, Error> {
111 match self.iter.next() {
112 Some(v) => seed.deserialize(ValueDeserializer::new(v)).map(Some),
113 None => Ok(None),
114 }
115 }
116
117 fn size_hint(&self) -> Option<usize> {
118 Some(self.iter.len())
119 }
120}
121
122struct MapAccessImpl<'a> {
123 iter: indexmap::map::Iter<'a, Value, Value>,
124 next_value: Option<&'a Value>,
125}
126
127impl<'de, 'a> MapAccess<'de> for MapAccessImpl<'a> {
128 type Error = Error;
129
130 fn next_key_seed<K: DeserializeSeed<'de>>(
131 &mut self,
132 seed: K,
133 ) -> Result<Option<K::Value>, Error> {
134 match self.iter.next() {
135 Some((k, v)) => {
136 self.next_value = Some(v);
137 seed.deserialize(ValueDeserializer::new(k)).map(Some)
138 }
139 None => Ok(None),
140 }
141 }
142
143 fn next_value_seed<V: DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value, Error> {
144 let v = self
145 .next_value
146 .take()
147 .ok_or_else(|| <Error as de::Error>::custom("next_value before next_key"))?;
148 seed.deserialize(ValueDeserializer::new(v))
149 }
150
151 fn size_hint(&self) -> Option<usize> {
152 Some(self.iter.len())
153 }
154}
155
156struct EnumAccessImpl<'a> {
157 variant: &'a str,
158 value: &'a Value,
159}
160
161impl<'de, 'a> serde::de::EnumAccess<'de> for EnumAccessImpl<'a> {
162 type Error = Error;
163 type Variant = VariantAccessImpl<'a>;
164
165 fn variant_seed<V: DeserializeSeed<'de>>(
166 self,
167 seed: V,
168 ) -> Result<(V::Value, Self::Variant), Error> {
169 let de = de::value::StrDeserializer::<Error>::new(self.variant);
170 let name: V::Value = seed.deserialize(de)?;
171 Ok((name, VariantAccessImpl { value: self.value }))
172 }
173}
174
175struct VariantAccessImpl<'a> {
176 value: &'a Value,
177}
178
179impl<'de, 'a> serde::de::VariantAccess<'de> for VariantAccessImpl<'a> {
180 type Error = Error;
181
182 fn unit_variant(self) -> Result<(), Error> {
183 match self.value {
184 Value::Null => Ok(()),
185 _ => Err(<Error as de::Error>::custom(
186 "unit variant must have Null payload",
187 )),
188 }
189 }
190
191 fn newtype_variant_seed<T: DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value, Error> {
192 seed.deserialize(ValueDeserializer::new(self.value))
193 }
194
195 fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value, Error> {
196 ValueDeserializer::new(self.value).deserialize_seq(visitor)
197 }
198
199 fn struct_variant<V: Visitor<'de>>(
200 self,
201 _fields: &'static [&'static str],
202 visitor: V,
203 ) -> Result<V::Value, Error> {
204 ValueDeserializer::new(self.value).deserialize_map(visitor)
205 }
206}
207
208pub fn from_str<T: DeserializeOwned>(s: &str) -> Result<T, Error> {
215 let value = Yaml::new().load_str(s)?;
216 T::deserialize(ValueDeserializer::new(&value))
217}
218
219pub fn from_slice<T: DeserializeOwned>(b: &[u8]) -> Result<T, Error> {
226 let s = std::str::from_utf8(b).map_err(Error::from)?;
227 from_str(s)
228}
229
230pub fn from_reader<R: Read, T: DeserializeOwned>(mut r: R) -> Result<T, Error> {
237 let mut buf = String::new();
238 r.read_to_string(&mut buf).map_err(Error::from)?;
239 from_str(&buf)
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    // Exact float comparison is intended: every expected value is exactly representable.
    #[test]
    #[allow(clippy::float_cmp)]
    fn from_str_parses_primitives() {
        assert!(from_str::<bool>("true").unwrap());

        let int: i64 = from_str("42").unwrap();
        assert_eq!(int, 42i64);

        let float: f64 = from_str("1.5").unwrap();
        assert_eq!(float, 1.5f64);

        let text: String = from_str("hello").unwrap();
        assert_eq!(text, "hello");

        let absent: Option<i32> = from_str("null").unwrap();
        assert_eq!(absent, None);

        let present: Option<i32> = from_str("7").unwrap();
        assert_eq!(present, Some(7));
    }

    #[test]
    fn vec_of_int_round_trips() {
        let parsed = from_str::<Vec<i32>>("- 1\n- 2\n- 3\n").unwrap();
        assert_eq!(parsed, [1, 2, 3]);
    }

    #[test]
    fn nested_seq_round_trips() {
        let parsed = from_str::<Vec<Vec<i32>>>("- - 1\n  - 2\n- - 3\n  - 4\n").unwrap();
        assert_eq!(parsed, vec![vec![1, 2], vec![3, 4]]);
    }

    #[test]
    fn struct_round_trips_through_from_str() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Cfg {
            name: String,
            version: u32,
            enabled: bool,
        }

        let expected = Cfg {
            name: String::from("rust"),
            version: 11,
            enabled: true,
        };
        let actual: Cfg = from_str("name: rust\nversion: 11\nenabled: true\n").unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn hashmap_round_trips_through_from_str() {
        use std::collections::HashMap;

        let table = from_str::<HashMap<String, i32>>("a: 1\nb: 2\n").unwrap();
        for (key, want) in [("a", 1), ("b", 2)] {
            assert_eq!(table.get(key), Some(&want));
        }
    }

    #[test]
    fn from_slice_and_from_reader_match_from_str() {
        type Doc = indexmap::IndexMap<String, serde_yaml::Value>;

        let input = "name: rust\nversion: 11\n";
        let via_str: Doc = from_str(input).unwrap();
        let via_slice: Doc = from_slice(input.as_bytes()).unwrap();
        let via_reader: Doc = from_reader(std::io::Cursor::new(input)).unwrap();

        assert_eq!(via_str, via_slice);
        assert_eq!(via_str, via_reader);
    }

    #[test]
    fn unit_variant_deserializes_from_string() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Color {
            Red,
            Green,
            Blue,
        }

        let color = from_str::<Color>("Red").unwrap();
        assert_eq!(color, Color::Red);
    }

    // Exact float comparison is intended: every expected value is exactly representable.
    #[test]
    #[allow(clippy::float_cmp)]
    fn tuple_variant_deserializes_from_tagged_map() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Shape {
            Circle(f64),
            Rect(f64, f64),
        }

        let circle = from_str::<Shape>("Circle: 1.5\n").unwrap();
        assert_eq!(circle, Shape::Circle(1.5));

        let rect = from_str::<Shape>("Rect:\n  - 2.0\n  - 3.0\n").unwrap();
        assert_eq!(rect, Shape::Rect(2.0, 3.0));
    }

    #[test]
    fn struct_variant_deserializes_from_tagged_map() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        enum Msg {
            Point { x: i32, y: i32 },
        }

        let point = from_str::<Msg>("Point:\n  x: 1\n  y: 2\n").unwrap();
        assert_eq!(point, Msg::Point { x: 1, y: 2 });
    }
}