1use std::{io::Read, marker::PhantomData};
2
3#[cfg(feature = "log")]
4use log::trace;
5use serde::de::{self, Unexpected};
6use serde::forward_to_deserialize_any;
7use xml::name::OwnedName;
8use xml::reader::{EventReader, ParserConfig, XmlEvent};
9
10use self::buffer::{BufferedXmlReader, ChildXmlBuffer, RootXmlBuffer};
11use self::map::MapAccess;
12use self::seq::SeqAccess;
13use self::var::EnumAccess;
14use crate::error::{Error, Result};
15use crate::{debug_expect, expect};
16
17mod buffer;
18mod map;
19mod seq;
20mod var;
21
22pub fn from_str<'de, T: de::Deserialize<'de>>(s: &str) -> Result<T> {
39 from_reader(s.as_bytes())
40}
41
42pub fn from_reader<'de, R: Read, T: de::Deserialize<'de>>(reader: R) -> Result<T> {
59 T::deserialize(&mut Deserializer::new_from_reader(reader))
60}
61
/// The deserializer type built by `new` / `new_from_reader`: its buffer is a
/// `RootXmlBuffer` wrapping the `xml-rs` `EventReader`.
type RootDeserializer<R> = Deserializer<R, RootXmlBuffer<R>>;
/// A deserializer whose buffer is a `ChildXmlBuffer` borrowed from a parent
/// (the shape returned by `Deserializer::child`). Not referenced in this view;
/// presumably used by the `seq`/`map`/`var` submodules — confirm there.
type ChildDeserializer<'parent, R> = Deserializer<R, ChildXmlBuffer<'parent, R>>;
64
65pub struct Deserializer<
66 R: Read, B: BufferedXmlReader<R> = RootXmlBuffer<R>,
68> {
69 depth: usize,
71 buffered_reader: B,
72 is_map_value: bool,
73 non_contiguous_seq_elements: bool,
74 marker: PhantomData<R>,
75}
76
77impl<'de, R: Read> RootDeserializer<R> {
78 pub fn new(reader: EventReader<R>) -> Self {
79 let buffered_reader = RootXmlBuffer::new(reader);
80
81 Deserializer {
82 buffered_reader,
83 depth: 0,
84 is_map_value: false,
85 non_contiguous_seq_elements: false,
86 marker: PhantomData,
87 }
88 }
89
90 pub fn new_from_reader(reader: R) -> Self {
91 let config = ParserConfig::new()
92 .trim_whitespace(true)
93 .whitespace_to_characters(true)
94 .cdata_to_characters(true)
95 .ignore_comments(true)
96 .coalesce_characters(true);
97
98 Self::new(EventReader::new_with_config(reader, config))
99 }
100
101 pub fn non_contiguous_seq_elements(mut self, set: bool) -> Self {
130 self.non_contiguous_seq_elements = set;
131 self
132 }
133}
134
135impl<'de, R: Read, B: BufferedXmlReader<R>> Deserializer<R, B> {
136 fn child<'a>(&'a mut self) -> Deserializer<R, ChildXmlBuffer<'a, R>> {
137 let Deserializer {
138 buffered_reader,
139 depth,
140 is_map_value,
141 non_contiguous_seq_elements,
142 ..
143 } = self;
144
145 Deserializer {
146 buffered_reader: buffered_reader.child_buffer(),
147 depth: *depth,
148 is_map_value: *is_map_value,
149 non_contiguous_seq_elements: *non_contiguous_seq_elements,
150 marker: PhantomData,
151 }
152 }
153
154 fn peek(&mut self) -> Result<&XmlEvent> {
156 let peeked = self.buffered_reader.peek()?;
157
158 #[cfg(feature = "log")]
159 trace!("Peeked {:?}", peeked);
160 Ok(peeked)
161 }
162
163 fn next(&mut self) -> Result<XmlEvent> {
165 let next = self.buffered_reader.next()?;
166
167 match next {
168 XmlEvent::StartElement { .. } => {
169 self.depth += 1;
170 }
171 XmlEvent::EndElement { .. } => {
172 self.depth -= 1;
173 }
174 _ => {}
175 }
176 #[cfg(feature = "log")]
177 trace!("Fetched {:?}", next);
178 Ok(next)
179 }
180
181 fn set_map_value(&mut self) {
182 self.is_map_value = true;
183 }
184
185 pub fn unset_map_value(&mut self) -> bool {
186 ::std::mem::replace(&mut self.is_map_value, false)
187 }
188
189 fn read_inner_value<V: de::Visitor<'de>, T, F: FnOnce(&mut Self) -> Result<T>>(
194 &mut self,
195 f: F,
196 ) -> Result<T> {
197 if self.unset_map_value() {
198 debug_expect!(self.next(), Ok(XmlEvent::StartElement { name, .. }) => {
199 let result = f(self)?;
200 self.expect_end_element(name)?;
201 Ok(result)
202 })
203 } else {
204 f(self)
205 }
206 }
207
208 fn expect_end_element(&mut self, start_name: OwnedName) -> Result<()> {
209 expect!(self.next()?, XmlEvent::EndElement { name, .. } => {
210 if name == start_name {
211 Ok(())
212 } else {
213 Err(Error::Custom { field: format!(
214 "End tag </{}> didn't match the start tag <{}>",
215 name.local_name,
216 start_name.local_name
217 ) })
218 }
219 })
220 }
221
222 fn prepare_parse_type<V: de::Visitor<'de>>(&mut self) -> Result<String> {
223 if let XmlEvent::StartElement { .. } = *self.peek()? {
224 self.set_map_value()
225 }
226 self.read_inner_value::<V, String, _>(|this| {
227 if let XmlEvent::EndElement { .. } = *this.peek()? {
228 return Err(Error::UnexpectedToken {
229 token: "EndElement".into(),
230 found: "Characters".into(),
231 });
232 }
233
234 expect!(this.next()?, XmlEvent::Characters(s) => {
235 return Ok(s)
236 })
237 })
238 }
239}
240
/// Generates a `deserialize_*` method for a scalar type parsed from text.
///
/// The generated method reads the value's text with `prepare_parse_type`,
/// parses it with `str::parse` into the type `$visit` accepts, and hands the
/// result to the visitor. Parse errors are converted with `?`.
macro_rules! deserialize_type {
    ($deserialize:ident => $visit:ident) => {
        fn $deserialize<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
            let raw_text = self.prepare_parse_type::<V>()?;
            visitor.$visit(raw_text.parse()?)
        }
    };
}
249
/// `serde::Deserializer` implementation that drives XML event consumption.
///
/// Scalar and string methods read the text content of the current value.
/// When positioned on a start tag, they treat that element as the value's
/// wrapper via `read_inner_value`. `deserialize_struct` and `deserialize_map`
/// consume one whole element and hand its contents and attributes to
/// `MapAccess`. Sequences and tuples are read through a child deserializer
/// and `SeqAccess`.
impl<'de, 'a, R: Read, B: BufferedXmlReader<R>> de::Deserializer<'de>
    for &'a mut Deserializer<R, B>
{
    type Error = Error;

    // Identifiers (field / variant names) fall back to `deserialize_any`.
    forward_to_deserialize_any! {
        identifier
    }

    /// Consumes one element: start tag, contents via `MapAccess` (attributes
    /// included), then the matching end tag.
    ///
    /// `MapAccess` is told whether the struct declares a `$value` field.
    fn deserialize_struct<V: de::Visitor<'de>>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value> {
        // The start tag is consumed explicitly below, so a pending map-value
        // flag is cleared rather than handled by `read_inner_value`.
        self.unset_map_value();
        expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
            let map_value = visitor.visit_map(MapAccess::new(
                self,
                attributes,
                fields.contains(&"$value")
            ))?;
            self.expect_end_element(name)?;
            Ok(map_value)
        })
    }

    // Numeric scalars: text content read by `prepare_parse_type`, parsed with
    // `str::parse` (see the `deserialize_type!` macro).
    deserialize_type!(deserialize_i8 => visit_i8);
    deserialize_type!(deserialize_i16 => visit_i16);
    deserialize_type!(deserialize_i32 => visit_i32);
    deserialize_type!(deserialize_i64 => visit_i64);
    deserialize_type!(deserialize_u8 => visit_u8);
    deserialize_type!(deserialize_u16 => visit_u16);
    deserialize_type!(deserialize_u32 => visit_u32);
    deserialize_type!(deserialize_u64 => visit_u64);
    deserialize_type!(deserialize_f32 => visit_f32);
    deserialize_type!(deserialize_f64 => visit_f64);

    /// Reads a boolean from text content.
    ///
    /// Accepts `true`/`1` and `false`/`0`. An empty value (end tag reached
    /// first) yields `false`. Any other text is an `invalid_value` error.
    fn deserialize_bool<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        if let XmlEvent::StartElement { .. } = *self.peek()? {
            self.set_map_value()
        }
        self.read_inner_value::<V, V::Value, _>(|this| {
            if let XmlEvent::EndElement { .. } = *this.peek()? {
                return visitor.visit_bool(false);
            }
            expect!(this.next()?, XmlEvent::Characters(s) => {
                match s.as_str() {
                    "true" | "1" => visitor.visit_bool(true),
                    "false" | "0" => visitor.visit_bool(false),
                    _ => Err(de::Error::invalid_value(Unexpected::Str(&s), &"a boolean")),
                }

            })
        })
    }

    /// Delegates to `deserialize_string`; the visitor receives the text.
    fn deserialize_char<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_string(visitor)
    }

    /// Delegates to `deserialize_string` (text is always owned here).
    fn deserialize_str<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_string(visitor)
    }

    /// Delegates to `deserialize_string`: bytes are read as XML text.
    fn deserialize_bytes<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_string(visitor)
    }

    /// Delegates to `deserialize_string`: bytes are read as XML text.
    fn deserialize_byte_buf<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.deserialize_string(visitor)
    }

    /// Accepts a unit value only if the next event (after an optional wrapper
    /// start tag) is an end tag, i.e. the value is empty.
    ///
    /// The end tag is only peeked inside the closure. When a wrapper start tag
    /// was consumed, `read_inner_value` consumes its matching end tag.
    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        if let XmlEvent::StartElement { .. } = *self.peek()? {
            self.set_map_value()
        }
        self.read_inner_value::<V, V::Value, _>(
            |this| expect!(this.peek()?, &XmlEvent::EndElement { .. } => visitor.visit_unit()),
        )
    }

    /// Treated exactly like `()`.
    fn deserialize_unit_struct<V: de::Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value> {
        self.deserialize_unit(visitor)
    }

    /// Transparent: the wrapped value is read from this same deserializer.
    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value> {
        visitor.visit_newtype_struct(self)
    }

    /// Treated exactly like a tuple of `len` elements.
    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
        self,
        _name: &'static str,
        len: usize,
        visitor: V,
    ) -> Result<V::Value> {
        self.deserialize_tuple(len, visitor)
    }

    /// Reads a fixed-length sequence through a child deserializer and
    /// `SeqAccess` (given `Some(len)`).
    fn deserialize_tuple<V: de::Visitor<'de>>(self, len: usize, visitor: V) -> Result<V::Value> {
        let child_deserializer = self.child();

        visitor.visit_seq(SeqAccess::new(child_deserializer, Some(len)))
    }

    /// Reads an enum through `EnumAccess`.
    ///
    /// Unlike the scalar methods, this does not set the map-value flag itself.
    /// It only unwraps an enclosing element if a caller already set it.
    fn deserialize_enum<V: de::Visitor<'de>>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value> {
        self.read_inner_value::<V, V::Value, _>(|this| visitor.visit_enum(EnumAccess::new(this)))
    }

    /// Reads text content as an owned string.
    ///
    /// An empty value (end tag reached first) yields `""`.
    fn deserialize_string<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        if let XmlEvent::StartElement { .. } = *self.peek()? {
            self.set_map_value()
        }
        self.read_inner_value::<V, V::Value, _>(|this| {
            if let XmlEvent::EndElement { .. } = *this.peek()? {
                return visitor.visit_str("");
            }
            expect!(this.next()?, XmlEvent::Characters(s) => {
                visitor.visit_string(s)
            })
        })
    }

    /// Reads a sequence of unknown length through a child deserializer and
    /// `SeqAccess` (given `None`).
    fn deserialize_seq<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        let child_deserializer = self.child();

        visitor.visit_seq(SeqAccess::new(child_deserializer, None))
    }

    /// Like `deserialize_struct`, but `MapAccess` is never told to expect a
    /// `$value` field.
    fn deserialize_map<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.unset_map_value();
        expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
            let map_value = visitor.visit_map(MapAccess::new(self, attributes, false))?;
            self.expect_end_element(name)?;
            Ok(map_value)
        })
    }

    /// `None` when the next event is an end tag (nothing left to read), else
    /// `Some` with the value read from this deserializer.
    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        match *self.peek()? {
            XmlEvent::EndElement { .. } => visitor.visit_none(),
            _ => visitor.visit_some(self),
        }
    }

    /// Skips one value by consuming events until `depth` returns to its
    /// starting level.
    ///
    /// A text event is skipped on its own. A start tag is skipped together
    /// with its whole subtree. NOTE(review): if the first event consumed is
    /// an end tag, `depth` drops below the start and the loop keeps consuming
    /// past it. This presumably relies on callers never invoking it when
    /// positioned on an end tag — confirm against `MapAccess`/`SeqAccess`.
    fn deserialize_ignored_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        self.unset_map_value();
        let depth = self.depth;
        loop {
            self.next()?;
            if self.depth == depth {
                break;
            }
        }
        visitor.visit_unit()
    }

    /// Self-describing fallback: a start tag is read as a map, an end tag as
    /// unit, anything else as a string.
    fn deserialize_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
        match *self.peek()? {
            XmlEvent::StartElement { .. } => self.deserialize_map(visitor),
            XmlEvent::EndElement { .. } => self.deserialize_unit(visitor),
            _ => self.deserialize_string(visitor),
        }
    }
}