xmlity_quick_xml/
de.rs

1use std::ops::Deref;
2
3use quick_xml::{
4    events::{attributes::Attribute, BytesStart, Event},
5    name::QName as QuickName,
6    NsReader,
7};
8
9use xmlity::{
10    de::{self, Error as _, Unexpected, Visitor},
11    Deserialize, ExpandedName, LocalName, QName, XmlNamespace,
12};
13
14use crate::{xml_namespace_from_resolve_result, HasQuickXmlAlternative, OwnedQuickName};
15
16use super::Error;
17
18pub fn from_str<'a, T>(s: &'a str) -> Result<T, Error>
19where
20    T: Deserialize<'a>,
21{
22    let mut deserializer = Deserializer::from(s.as_bytes());
23    T::deserialize(&mut deserializer)
24}
25
/// A coarse description of an upcoming event: nothing, text, CDATA, or an element
/// with its qualified name and optional namespace.
///
/// NOTE(review): not referenced anywhere else in this file.
pub enum Peeked<'a> {
    None,
    Text,
    CData,
    Element {
        name: QName<'a>,
        namespace: Option<XmlNamespace<'a>>,
    },
}
35
/// An xmlity deserializer that reads XML from an in-memory byte slice through
/// quick-xml's namespace-aware reader.
///
/// `Clone` is what makes speculative reads possible: `try_deserialize` and the
/// sub-sequence accessors work on a copy and only write it back when appropriate.
#[derive(Debug, Clone)]
pub struct Deserializer<'i> {
    reader: NsReader<&'i [u8]>,
    /// Start tags minus end tags consumed through `next_event`.
    /// NOTE(review): an `i16` counter limits nesting to 32767 levels; deeper input
    /// overflows it — confirm whether that bound is acceptable.
    current_depth: i16,
    /// One-event lookahead buffer filled by `peek_event` and drained by `next_event`.
    peeked_event: Option<Event<'i>>,
}
42
43impl<'i> From<NsReader<&'i [u8]>> for Deserializer<'i> {
44    fn from(reader: NsReader<&'i [u8]>) -> Self {
45        Self::new(reader)
46    }
47}
48
49impl<'i> From<&'i [u8]> for Deserializer<'i> {
50    fn from(buffer: &'i [u8]) -> Self {
51        Self::new(NsReader::from_reader(buffer))
52    }
53}
54
55impl<'i> Deserializer<'i> {
56    pub fn new(reader: NsReader<&'i [u8]>) -> Self {
57        let mut new = Self {
58            reader,
59            current_depth: 0,
60            peeked_event: None,
61        };
62
63        if let Some(Event::Decl(_)) = new.peek_event() {
64            let Some(Event::Decl(_decl)) = new.next_event() else {
65                unreachable!("peek_event returned Decl but next_event did not")
66            };
67        }
68
69        new
70    }
71
72    fn read_event(&mut self) -> Result<quick_xml::events::Event<'i>, Error> {
73        while let Ok(event) = self.reader.read_event() {
74            match event {
75                Event::Text(text) if text.clone().into_inner().trim_ascii().is_empty() => {
76                    continue;
77                }
78
79                event => return Ok(event),
80            }
81        }
82
83        Ok(quick_xml::events::Event::Eof)
84    }
85
86    fn read_until_element_end(&mut self, name: &QuickName, depth: i16) -> Result<(), Error> {
87        while let Some(event) = self.peek_event() {
88            let correct_name = match event {
89                Event::End(ref e) if e.name() == *name => true,
90                Event::Eof => return Err(Error::Unexpected(Unexpected::Eof)),
91                _ => false,
92            };
93
94            if correct_name && self.current_depth == depth {
95                return Ok(());
96            }
97
98            self.next_event();
99        }
100
101        Err(Error::Unexpected(de::Unexpected::Eof))
102    }
103
104    pub fn peek_event(&mut self) -> Option<&quick_xml::events::Event<'i>> {
105        if self.peeked_event.is_some() {
106            return self.peeked_event.as_ref();
107        }
108
109        self.peeked_event = self.read_event().ok();
110        self.peeked_event.as_ref()
111    }
112
113    pub fn next_event(&mut self) -> Option<quick_xml::events::Event<'i>> {
114        let event = if self.peeked_event.is_some() {
115            self.peeked_event.take()
116        } else {
117            self.read_event().ok()
118        };
119
120        if matches!(event, Some(Event::End(_))) {
121            self.current_depth -= 1;
122        }
123        if matches!(event, Some(Event::Start(_))) {
124            self.current_depth += 1;
125        }
126
127        event
128    }
129
130    pub fn create_sub_seq_access<'p>(&'p mut self) -> SubSeqAccess<'p, 'i> {
131        SubSeqAccess::Filled {
132            current: Some(self.clone()),
133            parent: self,
134        }
135    }
136
137    pub fn try_deserialize<T, E>(
138        &mut self,
139        closure: impl for<'a> FnOnce(&'a mut Deserializer<'i>) -> Result<T, E>,
140    ) -> Result<T, E> {
141        let mut sub_deserializer = self.clone();
142        let res = closure(&mut sub_deserializer);
143
144        if res.is_ok() {
145            *self = sub_deserializer;
146        }
147        res
148    }
149
150    pub fn expand_name<'a>(&self, qname: QuickName<'a>) -> ExpandedName<'a> {
151        let (resolve_result, _) = self.reader.resolve(qname, false);
152        let namespace = xml_namespace_from_resolve_result(resolve_result).map(|ns| ns.into_owned());
153
154        ExpandedName::new(LocalName::from_quick_xml(qname.local_name()), namespace)
155    }
156
157    pub fn resolve_bytes_start<'a>(&self, bytes_start: &'a BytesStart<'a>) -> ExpandedName<'a> {
158        self.expand_name(bytes_start.name())
159    }
160
161    pub fn resolve_attribute<'a>(&self, attribute: &'a Attribute<'a>) -> ExpandedName<'a> {
162        self.expand_name(attribute.key)
163    }
164}
165
/// Access to one element: its name, its attributes and, optionally, its children.
///
/// When dropped while it still owns the deserializer, it skips the element's unread
/// content (see its `Drop` impl).
pub struct ElementAccess<'a, 'r> {
    /// Parent deserializer; `None` once `children()` has taken it.
    deserializer: Option<&'a mut Deserializer<'r>>,
    /// Index of the next attribute of `bytes_start` to read.
    attribute_index: usize,
    /// The element's start tag; attributes are read from it.
    bytes_start: BytesStart<'r>,
    /// Deserializer depth recorded right after the start tag was consumed.
    start_depth: i16,
    /// `true` for a self-closing element (no children, no separate end tag).
    empty: bool,
}
173
174impl Drop for ElementAccess<'_, '_> {
175    fn drop(&mut self) {
176        self.try_end().ok();
177    }
178}
179
180impl<'r> ElementAccess<'_, 'r> {
181    fn deserializer(&self) -> &Deserializer<'r> {
182        self.deserializer
183            .as_ref()
184            .expect("Should not be called after ElementAccess has been consumed")
185    }
186
187    fn try_end(&mut self) -> Result<(), Error> {
188        if self.empty {
189            return Ok(());
190        }
191
192        if let Some(deserializer) = self.deserializer.as_mut() {
193            deserializer.read_until_element_end(&self.bytes_start.name(), self.start_depth)?;
194        }
195        Ok(())
196    }
197}
198
/// A single attribute handed to visitors: its expanded name and its value.
///
/// NOTE(review): the value is built from quick-xml's raw `Attribute::value` and is not
/// unescaped anywhere in this file — confirm whether visitors expect entity decoding.
pub struct AttributeAccess<'a> {
    name: ExpandedName<'a>,
    value: String,
}
203
204impl<'a> de::AttributeAccess<'a> for AttributeAccess<'a> {
205    type Error = Error;
206
207    fn name(&self) -> ExpandedName<'_> {
208        self.name.clone()
209    }
210
211    fn value(&self) -> &str {
212        self.value.as_str()
213    }
214}
215
/// A sequence with no elements: every read returns `Ok(None)`.
struct EmptySeqAccess;
217
218impl<'de> de::SeqAccess<'de> for EmptySeqAccess {
219    type Error = Error;
220    type SubAccess<'s>
221        = EmptySeqAccess
222    where
223        Self: 's;
224
225    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
226    where
227        T: Deserialize<'de>,
228    {
229        Ok(None)
230    }
231
232    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
233    where
234        T: Deserialize<'de>,
235    {
236        Ok(None)
237    }
238
239    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
240        Ok(EmptySeqAccess)
241    }
242}
243
/// Deserializer over a single attribute; `next_attribute` creates one per attribute read.
struct AttributeDeserializer<'a> {
    name: ExpandedName<'a>,
    value: String,
}
248
249impl<'a> xmlity::Deserializer<'a> for AttributeDeserializer<'a> {
250    type Error = Error;
251
252    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
253    where
254        V: Visitor<'a>,
255    {
256        visitor.visit_attribute(AttributeAccess {
257            name: self.name,
258            value: self.value,
259        })
260    }
261
262    fn deserialize_seq<V>(self, _: V) -> Result<V::Value, Self::Error>
263    where
264        V: Visitor<'a>,
265    {
266        Err(Self::Error::Unexpected(de::Unexpected::Seq))
267    }
268}
269
/// A nested cursor over the attributes of the same start tag.
///
/// It keeps its own position and writes it through `write_attribute_to` when dropped.
pub struct SubAttributesAccess<'a, 'r> {
    /// Used to resolve attribute names.
    deserializer: &'a Deserializer<'r>,
    /// The start tag whose attributes are read.
    bytes_start: &'a BytesStart<'r>,
    /// This cursor's position: index of the next attribute to read.
    attribute_index: usize,
    /// Where `attribute_index` is written back on drop.
    write_attribute_to: &'a mut usize,
}
276
277impl Drop for SubAttributesAccess<'_, '_> {
278    fn drop(&mut self) {
279        *self.write_attribute_to = self.attribute_index;
280    }
281}
282
283fn next_attribute<'a, 'de, T: Deserialize<'de>>(
284    deserializer: &'a Deserializer<'_>,
285    bytes_start: &'a BytesStart<'_>,
286    attribute_index: &'a mut usize,
287) -> Result<Option<T>, Error> {
288    let (attribute, key) = loop {
289        let Some(attribute) = bytes_start.attributes().nth(*attribute_index) else {
290            return Ok(None);
291        };
292
293        let attribute = attribute?;
294
295        let key = deserializer.resolve_attribute(&attribute).into_owned();
296
297        const XMLNS_NAMESPACE: XmlNamespace<'static> =
298            XmlNamespace::new_dangerous("http://www.w3.org/2000/xmlns/");
299
300        if key.namespace() == Some(&XMLNS_NAMESPACE) {
301            *attribute_index += 1;
302            continue;
303        }
304
305        break (attribute, key);
306    };
307
308    let value = String::from_utf8(attribute.value.into_owned())
309        .expect("attribute value should be valid utf8");
310
311    let deserializer = AttributeDeserializer { name: key, value };
312
313    let res = T::deserialize(deserializer)?;
314
315    // Only increment the index if the deserialization was successful
316    *attribute_index += 1;
317
318    Ok(Some(res))
319}
320
321impl<'de> de::AttributesAccess<'de> for SubAttributesAccess<'_, 'de> {
322    type Error = Error;
323
324    type SubAccess<'a>
325        = SubAttributesAccess<'a, 'de>
326    where
327        Self: 'a;
328
329    fn next_attribute<T>(&mut self) -> Result<Option<T>, Self::Error>
330    where
331        T: Deserialize<'de>,
332    {
333        next_attribute(
334            self.deserializer,
335            self.bytes_start,
336            &mut self.attribute_index,
337        )
338    }
339
340    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
341        Ok(Self::SubAccess {
342            deserializer: self.deserializer,
343            bytes_start: self.bytes_start,
344            attribute_index: self.attribute_index + 1,
345            write_attribute_to: self.write_attribute_to,
346        })
347    }
348}
349
350impl<'de> de::AttributesAccess<'de> for ElementAccess<'_, 'de> {
351    type Error = Error;
352
353    type SubAccess<'a>
354        = SubAttributesAccess<'a, 'de>
355    where
356        Self: 'a;
357
358    fn next_attribute<T>(&mut self) -> Result<Option<T>, Self::Error>
359    where
360        T: Deserialize<'de>,
361    {
362        next_attribute(
363            self.deserializer
364                .as_ref()
365                .expect("deserializer should be set"),
366            &self.bytes_start,
367            &mut self.attribute_index,
368        )
369    }
370
371    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
372        Ok(Self::SubAccess {
373            bytes_start: &self.bytes_start,
374            attribute_index: self.attribute_index,
375            write_attribute_to: &mut self.attribute_index,
376            deserializer: self
377                .deserializer
378                .as_ref()
379                .expect("Should not be called after ElementAccess has been consumed"),
380        })
381    }
382}
383
384impl<'a, 'de> de::ElementAccess<'de> for ElementAccess<'a, 'de> {
385    type ChildrenAccess = ChildrenAccess<'a, 'de>;
386
387    fn name(&self) -> ExpandedName<'_> {
388        self.deserializer().resolve_bytes_start(&self.bytes_start)
389    }
390
391    fn children(mut self) -> Result<Self::ChildrenAccess, Self::Error> {
392        Ok(if self.empty {
393            ChildrenAccess::Empty
394        } else {
395            let deserializer = self
396                .deserializer
397                .take()
398                .expect("Should not be called after ElementAccess has been consumed");
399
400            ChildrenAccess::Filled {
401                expected_end: QName::from_quick_xml(self.bytes_start.name()).into_owned(),
402                start_depth: self.start_depth,
403                deserializer,
404            }
405        })
406    }
407}
408
/// Sequential access to an element's children.
///
/// `Filled` holds the parent deserializer plus what is needed to recognise the
/// element's end tag; `Empty` is used for self-closing elements and yields nothing.
pub enum ChildrenAccess<'a, 'r> {
    Filled {
        /// Name of the end tag that terminates the children.
        expected_end: QName<'static>,
        deserializer: &'a mut Deserializer<'r>,
        /// Deserializer depth recorded right after the element's start tag was consumed.
        start_depth: i16,
    },
    Empty,
}
417
418impl Drop for ChildrenAccess<'_, '_> {
419    fn drop(&mut self) {
420        let ChildrenAccess::Filled {
421            expected_end,
422            deserializer,
423            start_depth,
424        } = self
425        else {
426            return;
427        };
428
429        deserializer
430            .read_until_element_end(&OwnedQuickName::new(expected_end).as_ref(), *start_depth)
431            .unwrap();
432    }
433}
434
435impl<'r> de::SeqAccess<'r> for ChildrenAccess<'_, 'r> {
436    type Error = Error;
437
438    type SubAccess<'s>
439        = SubSeqAccess<'s, 'r>
440    where
441        Self: 's;
442
443    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
444    where
445        T: Deserialize<'r>,
446    {
447        let ChildrenAccess::Filled {
448            expected_end,
449            deserializer,
450            start_depth,
451        } = self
452        else {
453            return Ok(None);
454        };
455
456        let current_depth = deserializer.current_depth;
457
458        if let Some(quick_xml::events::Event::End(bytes_end)) = deserializer.peek_event() {
459            if OwnedQuickName::new(expected_end).as_ref() != bytes_end.name()
460                && current_depth == *start_depth
461            {
462                return Err(Error::custom(format!(
463                    "Expected end of element {}, found end of element {}",
464                    expected_end,
465                    QName::from_quick_xml(bytes_end.name())
466                )));
467            }
468
469            return Ok(None);
470        }
471
472        deserializer
473            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
474            .map(Some)
475    }
476
477    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
478    where
479        T: Deserialize<'r>,
480    {
481        let ChildrenAccess::Filled {
482            expected_end,
483            deserializer,
484            start_depth,
485        } = self
486        else {
487            return Ok(None);
488        };
489
490        let current_depth = deserializer.current_depth;
491
492        if let Some(quick_xml::events::Event::End(bytes_end)) = deserializer.peek_event() {
493            if OwnedQuickName::new(expected_end).as_ref() != bytes_end.name()
494                && current_depth == *start_depth
495            {
496                return Err(Error::custom(format!(
497                    "Expected end of element {}, found end of element {}",
498                    expected_end,
499                    QName::from_quick_xml(bytes_end.name())
500                )));
501            }
502
503            return Ok(None);
504        }
505
506        deserializer
507            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
508            .map(Some)
509    }
510
511    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
512        let ChildrenAccess::Filled { deserializer, .. } = self else {
513            return Ok(SubSeqAccess::Empty);
514        };
515
516        Ok(deserializer.create_sub_seq_access())
517    }
518}
519
/// Sequence access that reads consecutive values straight from the deserializer;
/// created by `deserialize_seq`.
pub struct SeqAccess<'a, 'r> {
    deserializer: &'a mut Deserializer<'r>,
}
523
/// A nested sequence cursor.
///
/// `Filled` works on `current`, a clone of the parent deserializer, and copies it
/// back into `parent` when dropped, so everything read through it is committed to the
/// parent. `Empty` yields no elements.
// `Filled` stores a whole `Deserializer` by value, making it far larger than `Empty`.
#[allow(clippy::large_enum_variant)]
pub enum SubSeqAccess<'p, 'r> {
    Filled {
        /// `Some` until the `Drop` impl moves it back into `parent`.
        current: Option<Deserializer<'r>>,
        parent: &'p mut Deserializer<'r>,
    },
    Empty,
}
532
533impl Drop for SubSeqAccess<'_, '_> {
534    fn drop(&mut self) {
535        if let SubSeqAccess::Filled { current, parent } = self {
536            **parent = current.take().expect("SubSeqAccess dropped twice");
537        }
538    }
539}
540
541impl<'r> de::SeqAccess<'r> for SubSeqAccess<'_, 'r> {
542    type Error = Error;
543
544    type SubAccess<'s>
545        = SubSeqAccess<'s, 'r>
546    where
547        Self: 's;
548
549    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
550    where
551        T: Deserialize<'r>,
552    {
553        let Self::Filled { current, .. } = self else {
554            return Ok(None);
555        };
556
557        current
558            .as_mut()
559            .expect("SubSeqAccess used after drop")
560            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
561            .map(Some)
562    }
563
564    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
565    where
566        T: Deserialize<'r>,
567    {
568        let Self::Filled { current, .. } = self else {
569            return Ok(None);
570        };
571        current
572            .as_mut()
573            .expect("SubSeqAccess used after drop")
574            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
575            .map(Some)
576    }
577
578    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
579        let Self::Filled { current, .. } = self else {
580            return Ok(SubSeqAccess::Empty);
581        };
582
583        Ok(current
584            .as_mut()
585            .expect("SubSeqAccess used after drop")
586            .create_sub_seq_access())
587    }
588}
589
590impl<'r> de::SeqAccess<'r> for SeqAccess<'_, 'r> {
591    type Error = Error;
592
593    type SubAccess<'s>
594        = SubSeqAccess<'s, 'r>
595    where
596        Self: 's;
597
598    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
599    where
600        T: Deserialize<'r>,
601    {
602        self.deserializer
603            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
604            .map(Some)
605    }
606
607    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
608    where
609        T: Deserialize<'r>,
610    {
611        self.deserializer
612            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
613            .map(Some)
614    }
615
616    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
617        Ok(SubSeqAccess::Filled {
618            current: Some(self.deserializer.clone()),
619            parent: self.deserializer,
620        })
621    }
622}
623
624impl<'r> xmlity::Deserializer<'r> for &mut Deserializer<'r> {
625    type Error = Error;
626
627    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
628    where
629        V: de::Visitor<'r>,
630    {
631        let event = self.next_event().ok_or_else(|| Error::custom("EOF"))?;
632
633        match event {
634            quick_xml::events::Event::Start(bytes_start) => {
635                let bytes_start_name = bytes_start.name().0.to_owned();
636
637                let value = Visitor::visit_element(
638                    visitor,
639                    ElementAccess {
640                        bytes_start,
641                        start_depth: self.current_depth,
642                        deserializer: Some(self),
643                        empty: false,
644                        attribute_index: 0,
645                    },
646                )?;
647
648                let end_event = self.next_event().ok_or_else(|| Error::custom("EOF"))?;
649
650                let mut success = false;
651
652                if let quick_xml::events::Event::End(bytes_end) = &end_event {
653                    if bytes_end.name() == QuickName(&bytes_start_name) {
654                        success = true;
655                    }
656                }
657
658                if success {
659                    Ok(value)
660                } else {
661                    Err(Error::custom("No matching end element"))
662                }
663            }
664            quick_xml::events::Event::End(_bytes_end) => {
665                Err(Error::custom("Unexpected end element"))
666            }
667            quick_xml::events::Event::Empty(bytes_start) => visitor.visit_element(ElementAccess {
668                bytes_start: bytes_start.into_owned().clone(),
669                start_depth: self.current_depth,
670                deserializer: Some(self),
671                empty: true,
672                attribute_index: 0,
673            }),
674            quick_xml::events::Event::Text(bytes_text) => visitor.visit_text(bytes_text.deref()),
675            quick_xml::events::Event::CData(bytes_cdata) => {
676                visitor.visit_cdata(bytes_cdata.deref())
677            }
678            quick_xml::events::Event::Comment(bytes_text) => {
679                visitor.visit_comment(bytes_text.deref())
680            }
681            quick_xml::events::Event::Decl(bytes_decl) => visitor.visit_decl(
682                bytes_decl.version()?,
683                match bytes_decl.encoding() {
684                    Some(Ok(encoding)) => Some(encoding),
685                    Some(Err(err)) => return Err(Error::QuickXml(err.into())),
686                    None => None,
687                },
688                match bytes_decl.standalone() {
689                    Some(Ok(standalone)) => Some(standalone),
690                    Some(Err(err)) => return Err(Error::QuickXml(err.into())),
691                    None => None,
692                },
693            ),
694            quick_xml::events::Event::PI(bytes_pi) => visitor.visit_pi(bytes_pi.deref()),
695            quick_xml::events::Event::DocType(bytes_text) => {
696                visitor.visit_doctype(bytes_text.deref())
697            }
698            quick_xml::events::Event::Eof => Err(Error::custom("Unexpected EOF")),
699        }
700    }
701
702    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
703    where
704        V: de::Visitor<'r>,
705    {
706        visitor.visit_seq(SeqAccess { deserializer: self })
707    }
708}
709
710impl<'r> xmlity::Deserializer<'r> for Deserializer<'r> {
711    type Error = Error;
712
713    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
714    where
715        V: de::Visitor<'r>,
716    {
717        (&mut self).deserialize_any(visitor)
718    }
719
720    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
721    where
722        V: de::Visitor<'r>,
723    {
724        (&mut self).deserialize_seq(visitor)
725    }
726}